diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index af7d96b3d056..c0ca1524ec60 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -103,6 +103,7 @@ + @@ -345,6 +346,8 @@ + + diff --git a/sdk/core/Azure.Core/src/DiagnosticsOptions.cs b/sdk/core/Azure.Core/src/DiagnosticsOptions.cs index bfec3b23b5da..6fe5f5b55375 100644 --- a/sdk/core/Azure.Core/src/DiagnosticsOptions.cs +++ b/sdk/core/Azure.Core/src/DiagnosticsOptions.cs @@ -41,6 +41,8 @@ internal DiagnosticsOptions(DiagnosticsOptions? diagnosticsOptions) } else { + // These values are similar to the default values in System.ClientModel.Primitives.ClientLoggingOptions and both + // should be kept in sync. When updating, update the default values in both classes. LoggedHeaderNames = new List() { "x-ms-request-id", diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs index 675d7c513786..2cce8ba3d9f0 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net6.0.cs @@ -83,6 +83,19 @@ public enum ClientErrorBehaviors Default = 0, NoThrow = 1, } + public partial class ClientLoggingOptions + { + public ClientLoggingOptions() { } + public System.Collections.Generic.IList AllowedHeaderNames { get { throw null; } } + public System.Collections.Generic.IList AllowedQueryParameters { get { throw null; } } + public bool? EnableLogging { get { throw null; } set { } } + public bool? EnableMessageContentLogging { get { throw null; } set { } } + public bool? EnableMessageLogging { get { throw null; } set { } } + public Microsoft.Extensions.Logging.ILoggerFactory? LoggerFactory { get { throw null; } set { } } + public int? MessageContentSizeLimit { get { throw null; } set { } } + protected void AssertNotFrozen() { } + public virtual void Freeze() { } + } public sealed partial class ClientPipeline { internal ClientPipeline() { } @@ -95,6 +108,8 @@ public void Send(System.ClientModel.Primitives.PipelineMessage message) { } public partial class ClientPipelineOptions { public ClientPipelineOptions() { } + public System.ClientModel.Primitives.ClientLoggingOptions? ClientLoggingOptions { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelinePolicy? MessageLoggingPolicy { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelinePolicy? RetryPolicy { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineTransport? Transport { get { throw null; } set { } } @@ -105,6 +120,7 @@ public virtual void Freeze() { } public partial class ClientRetryPolicy : System.ClientModel.Primitives.PipelinePolicy { public ClientRetryPolicy(int maxRetries = 3) { } + public ClientRetryPolicy(int maxRetries, bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public static System.ClientModel.Primitives.ClientRetryPolicy Default { get { throw null; } } protected virtual System.TimeSpan GetNextDelay(System.ClientModel.Primitives.PipelineMessage message, int tryCount) { throw null; } protected virtual void OnRequestSent(System.ClientModel.Primitives.PipelineMessage message) { } @@ -129,6 +145,7 @@ public partial class HttpClientPipelineTransport : System.ClientModel.Primitives { public HttpClientPipelineTransport() { } public HttpClientPipelineTransport(System.Net.Http.HttpClient client) { } + public HttpClientPipelineTransport(System.Net.Http.HttpClient? client, bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public static System.ClientModel.Primitives.HttpClientPipelineTransport Shared { get { throw null; } } protected override System.ClientModel.Primitives.PipelineMessage CreateMessageCore() { throw null; } public void Dispose() { } @@ -158,6 +175,13 @@ public JsonModelConverter(System.ClientModel.Primitives.ModelReaderWriterOptions public override System.ClientModel.Primitives.IJsonModel Read(ref System.Text.Json.Utf8JsonReader reader, System.Type typeToConvert, System.Text.Json.JsonSerializerOptions options) { throw null; } public override void Write(System.Text.Json.Utf8JsonWriter writer, System.ClientModel.Primitives.IJsonModel value, System.Text.Json.JsonSerializerOptions options) { } } + public partial class MessageLoggingPolicy : System.ClientModel.Primitives.PipelinePolicy + { + public MessageLoggingPolicy(System.ClientModel.Primitives.ClientLoggingOptions? options = null) { } + public static System.ClientModel.Primitives.MessageLoggingPolicy Default { get { throw null; } } + public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } + public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } + } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors | System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -279,6 +303,7 @@ protected PipelineResponseHeaders() { } public abstract partial class PipelineTransport : System.ClientModel.Primitives.PipelinePolicy { protected PipelineTransport() { } + protected PipelineTransport(bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public System.ClientModel.Primitives.PipelineMessage CreateMessage() { throw null; } protected abstract System.ClientModel.Primitives.PipelineMessage CreateMessageCore(); public void Process(System.ClientModel.Primitives.PipelineMessage message) { } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.net8.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.net8.0.cs index 675d7c513786..2cce8ba3d9f0 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.net8.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.net8.0.cs @@ -83,6 +83,19 @@ public enum ClientErrorBehaviors Default = 0, NoThrow = 1, } + public partial class ClientLoggingOptions + { + public ClientLoggingOptions() { } + public System.Collections.Generic.IList AllowedHeaderNames { get { throw null; } } + public System.Collections.Generic.IList AllowedQueryParameters { get { throw null; } } + public bool? EnableLogging { get { throw null; } set { } } + public bool? EnableMessageContentLogging { get { throw null; } set { } } + public bool? EnableMessageLogging { get { throw null; } set { } } + public Microsoft.Extensions.Logging.ILoggerFactory? LoggerFactory { get { throw null; } set { } } + public int? MessageContentSizeLimit { get { throw null; } set { } } + protected void AssertNotFrozen() { } + public virtual void Freeze() { } + } public sealed partial class ClientPipeline { internal ClientPipeline() { } @@ -95,6 +108,8 @@ public void Send(System.ClientModel.Primitives.PipelineMessage message) { } public partial class ClientPipelineOptions { public ClientPipelineOptions() { } + public System.ClientModel.Primitives.ClientLoggingOptions? ClientLoggingOptions { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelinePolicy? MessageLoggingPolicy { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelinePolicy? RetryPolicy { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineTransport? Transport { get { throw null; } set { } } @@ -105,6 +120,7 @@ public virtual void Freeze() { } public partial class ClientRetryPolicy : System.ClientModel.Primitives.PipelinePolicy { public ClientRetryPolicy(int maxRetries = 3) { } + public ClientRetryPolicy(int maxRetries, bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public static System.ClientModel.Primitives.ClientRetryPolicy Default { get { throw null; } } protected virtual System.TimeSpan GetNextDelay(System.ClientModel.Primitives.PipelineMessage message, int tryCount) { throw null; } protected virtual void OnRequestSent(System.ClientModel.Primitives.PipelineMessage message) { } @@ -129,6 +145,7 @@ public partial class HttpClientPipelineTransport : System.ClientModel.Primitives { public HttpClientPipelineTransport() { } public HttpClientPipelineTransport(System.Net.Http.HttpClient client) { } + public HttpClientPipelineTransport(System.Net.Http.HttpClient? client, bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public static System.ClientModel.Primitives.HttpClientPipelineTransport Shared { get { throw null; } } protected override System.ClientModel.Primitives.PipelineMessage CreateMessageCore() { throw null; } public void Dispose() { } @@ -158,6 +175,13 @@ public JsonModelConverter(System.ClientModel.Primitives.ModelReaderWriterOptions public override System.ClientModel.Primitives.IJsonModel Read(ref System.Text.Json.Utf8JsonReader reader, System.Type typeToConvert, System.Text.Json.JsonSerializerOptions options) { throw null; } public override void Write(System.Text.Json.Utf8JsonWriter writer, System.ClientModel.Primitives.IJsonModel value, System.Text.Json.JsonSerializerOptions options) { } } + public partial class MessageLoggingPolicy : System.ClientModel.Primitives.PipelinePolicy + { + public MessageLoggingPolicy(System.ClientModel.Primitives.ClientLoggingOptions? options = null) { } + public static System.ClientModel.Primitives.MessageLoggingPolicy Default { get { throw null; } } + public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } + public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } + } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, [System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute(System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.NonPublicConstructors | System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -279,6 +303,7 @@ protected PipelineResponseHeaders() { } public abstract partial class PipelineTransport : System.ClientModel.Primitives.PipelinePolicy { protected PipelineTransport() { } + protected PipelineTransport(bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public System.ClientModel.Primitives.PipelineMessage CreateMessage() { throw null; } protected abstract System.ClientModel.Primitives.PipelineMessage CreateMessageCore(); public void Process(System.ClientModel.Primitives.PipelineMessage message) { } diff --git a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs index ce85c5d8b83d..c26dae86758c 100644 --- a/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs +++ b/sdk/core/System.ClientModel/api/System.ClientModel.netstandard2.0.cs @@ -83,6 +83,19 @@ public enum ClientErrorBehaviors Default = 0, NoThrow = 1, } + public partial class ClientLoggingOptions + { + public ClientLoggingOptions() { } + public System.Collections.Generic.IList AllowedHeaderNames { get { throw null; } } + public System.Collections.Generic.IList AllowedQueryParameters { get { throw null; } } + public bool? EnableLogging { get { throw null; } set { } } + public bool? EnableMessageContentLogging { get { throw null; } set { } } + public bool? EnableMessageLogging { get { throw null; } set { } } + public Microsoft.Extensions.Logging.ILoggerFactory? LoggerFactory { get { throw null; } set { } } + public int? MessageContentSizeLimit { get { throw null; } set { } } + protected void AssertNotFrozen() { } + public virtual void Freeze() { } + } public sealed partial class ClientPipeline { internal ClientPipeline() { } @@ -95,6 +108,8 @@ public void Send(System.ClientModel.Primitives.PipelineMessage message) { } public partial class ClientPipelineOptions { public ClientPipelineOptions() { } + public System.ClientModel.Primitives.ClientLoggingOptions? ClientLoggingOptions { get { throw null; } set { } } + public System.ClientModel.Primitives.PipelinePolicy? MessageLoggingPolicy { get { throw null; } set { } } public System.TimeSpan? NetworkTimeout { get { throw null; } set { } } public System.ClientModel.Primitives.PipelinePolicy? RetryPolicy { get { throw null; } set { } } public System.ClientModel.Primitives.PipelineTransport? Transport { get { throw null; } set { } } @@ -105,6 +120,7 @@ public virtual void Freeze() { } public partial class ClientRetryPolicy : System.ClientModel.Primitives.PipelinePolicy { public ClientRetryPolicy(int maxRetries = 3) { } + public ClientRetryPolicy(int maxRetries, bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public static System.ClientModel.Primitives.ClientRetryPolicy Default { get { throw null; } } protected virtual System.TimeSpan GetNextDelay(System.ClientModel.Primitives.PipelineMessage message, int tryCount) { throw null; } protected virtual void OnRequestSent(System.ClientModel.Primitives.PipelineMessage message) { } @@ -129,6 +145,7 @@ public partial class HttpClientPipelineTransport : System.ClientModel.Primitives { public HttpClientPipelineTransport() { } public HttpClientPipelineTransport(System.Net.Http.HttpClient client) { } + public HttpClientPipelineTransport(System.Net.Http.HttpClient? client, bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public static System.ClientModel.Primitives.HttpClientPipelineTransport Shared { get { throw null; } } protected override System.ClientModel.Primitives.PipelineMessage CreateMessageCore() { throw null; } public void Dispose() { } @@ -157,6 +174,13 @@ public JsonModelConverter(System.ClientModel.Primitives.ModelReaderWriterOptions public override System.ClientModel.Primitives.IJsonModel Read(ref System.Text.Json.Utf8JsonReader reader, System.Type typeToConvert, System.Text.Json.JsonSerializerOptions options) { throw null; } public override void Write(System.Text.Json.Utf8JsonWriter writer, System.ClientModel.Primitives.IJsonModel value, System.Text.Json.JsonSerializerOptions options) { } } + public partial class MessageLoggingPolicy : System.ClientModel.Primitives.PipelinePolicy + { + public MessageLoggingPolicy(System.ClientModel.Primitives.ClientLoggingOptions? options = null) { } + public static System.ClientModel.Primitives.MessageLoggingPolicy Default { get { throw null; } } + public sealed override void Process(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { } + public sealed override System.Threading.Tasks.ValueTask ProcessAsync(System.ClientModel.Primitives.PipelineMessage message, System.Collections.Generic.IReadOnlyList pipeline, int currentIndex) { throw null; } + } public static partial class ModelReaderWriter { public static object? Read(System.BinaryData data, System.Type returnType, System.ClientModel.Primitives.ModelReaderWriterOptions? options = null) { throw null; } @@ -277,6 +301,7 @@ protected PipelineResponseHeaders() { } public abstract partial class PipelineTransport : System.ClientModel.Primitives.PipelinePolicy { protected PipelineTransport() { } + protected PipelineTransport(bool enableLogging, Microsoft.Extensions.Logging.ILoggerFactory? loggerFactory) { } public System.ClientModel.Primitives.PipelineMessage CreateMessage() { throw null; } protected abstract System.ClientModel.Primitives.PipelineMessage CreateMessageCore(); public void Process(System.ClientModel.Primitives.PipelineMessage message) { } diff --git a/sdk/core/System.ClientModel/samples/Logging.md b/sdk/core/System.ClientModel/samples/Logging.md new file mode 100644 index 000000000000..c36f2360b3cb --- /dev/null +++ b/sdk/core/System.ClientModel/samples/Logging.md @@ -0,0 +1,158 @@ +# System.ClientModel-based client logging samples + +## Introduction + +Clients built on `System.ClientModel` emit log messages by default. These log messages include information about HTTP message requests and responses, retries, exceptions thrown in the transport, and delays in receiving responses. + +Clients can be configured to completely disable logging, or enable additional log messages. + +By default, logs are written to Event Source. Clients can be configured to write logs to ILogger instead by providing an ILoggerFactory to the client in `ClientLoggingOptions`. + +## Using ILoggerFactory to capture logs + +Here is an example of how a client can be configured to use `ILogger`. Information about acquiring or defining an `ILoggerFactory` can be found in the documentation for [`Microsoft.Extensions.Logging`](https://learn.microsoft.com/dotnet/core/extensions/logging?tabs=command-line). + +This method of creating an `ILoggerFactory` is a trivial example and is only suitable for simple console apps. + +```C# Snippet:UseILoggerFactoryToCaptureLogs +using ILoggerFactory factory = LoggerFactory.Create(builder => +{ + builder.AddConsole().SetMinimumLevel(LogLevel.Information); +}); + +ClientLoggingOptions loggingOptions = new() +{ + LoggerFactory = factory +}; + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; + +// Create and use client as usual +``` + +Some sensitive headers and query parameters are not logged by default and are displayed as "REDACTED". To include them in logs add them to `ClientLoggingOptions.AllowedHeaderNames` or `ClientLoggingOptions.AllowedQueryParameters`. + +```C# Snippet:LoggingRedactedHeaderILogger +using ILoggerFactory factory = LoggerFactory.Create(builder => +{ + builder.AddConsole(); +}); + +ClientLoggingOptions loggingOptions = new() +{ + LoggerFactory = factory +}; +loggingOptions.AllowedHeaderNames.Add("Request-Id"); +loggingOptions.AllowedQueryParameters.Add("api-version"); + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; +``` + +You can also disable redaction completely by adding a `"*"` to `ClientLoggingOptions.AllowedHeaderNames` or `ClientLoggingOptions.AllowedQueryParameters`. + +```C# Snippet:LoggingAllRedactedHeadersILogger +using ILoggerFactory factory = LoggerFactory.Create(builder => +{ + builder.AddConsole(); +}); + +ClientLoggingOptions loggingOptions = new() +{ + LoggerFactory = factory +}; +loggingOptions.AllowedHeaderNames.Add("*"); +loggingOptions.AllowedQueryParameters.Add("*"); + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; +``` + +By default, only URI and header names are logged. To enable content logging, set the logging level to `LogLevel.Debug` and set the `ClientLoggingOptions.EnableMessageContentLogging` client option: + + +```C# Snippet:EnableContentLoggingILogger +using ILoggerFactory factory = LoggerFactory.Create(builder => +{ + builder.AddConsole().SetMinimumLevel(LogLevel.Debug); +}); + +ClientLoggingOptions loggingOptions = new() +{ + LoggerFactory = factory, + EnableMessageContentLogging = true +}; + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; +``` + +## Using Event Source to capture logs + +If an `ILoggerFactory` is not provided to the client, and logging is enabled, logs will be written to Event Source. The name of the Event Source is "System.ClientModel". Event Source logs can be collected in a few ways, as described in the [Event Source documentation for collecting traces](https://learn.microsoft.com/dotnet/core/diagnostics/eventsource-collect-and-view-traces). + +This sample uses an Event Listener to collect logs. It uses the `ConsoleWriterEventListener` as defined in the [EventListener section](https://learn.microsoft.com/dotnet/core/diagnostics/eventsource-collect-and-view-traces#eventlistener) of the Event Source documentation above. + +```C# Snippet:UseEventSourceToCaptureLogs +// In order for an event listener to collect logs, it must be in scope and active +// while the client library is in use. If the listener is disposed or otherwise +// out of scope, logs cannot be collected. +using ConsoleWriterEventListener listener = new(); + +// Create and use client as usual +``` + +Some sensitive headers and query parameters are not logged by default and are displayed as "REDACTED". To include them in logs add them to `ClientLoggingOptions.AllowedHeaderNames` or `ClientLoggingOptions.AllowedQueryParameters`. + +```C# Snippet:LoggingRedactedHeaderEventSource +using ConsoleWriterEventListener listener = new(); + +ClientLoggingOptions loggingOptions = new(); +loggingOptions.AllowedHeaderNames.Add("Request-Id"); +loggingOptions.AllowedQueryParameters.Add("api-version"); + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; +``` + +You can also disable redaction completely by adding a `"*"` to `ClientLoggingOptions.AllowedHeaderNames` or `ClientLoggingOptions.AllowedQueryParameters`. + +```C# Snippet:LoggingAllRedactedHeadersEventSource +using ConsoleWriterEventListener listener = new(); + +ClientLoggingOptions loggingOptions = new(); +loggingOptions.AllowedHeaderNames.Add("*"); +loggingOptions.AllowedQueryParameters.Add("*"); + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; +``` + +By default only URI and headers are logged. To enable content logging, set the logging level to `EventLevel.Verbose` and set the `ClientLoggingOptions.EnableMessageContentLogging` client option: + +```C# Snippet:EnableContentLoggingEventSource +using ConsoleWriterEventListener listener = new(); + +ClientLoggingOptions loggingOptions = new() +{ + EnableMessageContentLogging = true +}; + +MapsClientOptions options = new() +{ + ClientLoggingOptions = loggingOptions +}; +``` diff --git a/sdk/core/System.ClientModel/src/Internal/ChangeTrackingStringList.cs b/sdk/core/System.ClientModel/src/Internal/ChangeTrackingStringList.cs new file mode 100644 index 000000000000..7bfbc246e7f9 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/ChangeTrackingStringList.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; + +namespace System.ClientModel.Internal; + +internal class ChangeTrackingStringList : IList +{ + private IList _list; + private bool _frozen = false; + private bool _tracking = true; + + public ChangeTrackingStringList() + { + _list = []; + } + + public ChangeTrackingStringList(IEnumerable collection) + { + _list = new List(collection); + } + + public bool HasChanged { get; private set; } + + public void Freeze() + { + _frozen = true; + } + + public void AssertNotFrozen() + { + if (_frozen) + { + throw new InvalidOperationException("Cannot change any client pipeline options after the client pipeline has been created."); + } + } + + #region IList implementation + + public string this[int index] + { + get => _list[index]; + set + { + AssertNotFrozen(); + _list[index] = value; + + HasChanged |= _tracking; + } + } + + public int Count => _list.Count; + + public bool IsReadOnly => _list.IsReadOnly; + + public void Add(string item) + { + AssertNotFrozen(); + _list.Add(item); + + HasChanged |= _tracking; + } + + public void Clear() + { + AssertNotFrozen(); + int count = _list.Count; + + _list.Clear(); + + HasChanged |= _tracking && (count != 0); + } + + public bool Contains(string item) => _list.Contains(item); + + public void CopyTo(string[] array, int arrayIndex) => _list.CopyTo(array, arrayIndex); + + public IEnumerator GetEnumerator() => _list.GetEnumerator(); + + public int IndexOf(string item) => _list.IndexOf(item); + + public void Insert(int index, string item) + { + AssertNotFrozen(); + _list.Insert(index, item); + + HasChanged |= _tracking; + } + + public bool Remove(string item) + { + AssertNotFrozen(); + bool removed = _list.Remove(item); + + HasChanged |= _tracking && removed; + + return removed; + } + + public void RemoveAt(int index) + { + AssertNotFrozen(); + _list.RemoveAt(index); + + HasChanged |= _tracking; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + #endregion +} diff --git a/sdk/core/System.ClientModel/src/Internal/ContentTypeUtilities.cs b/sdk/core/System.ClientModel/src/Internal/ContentTypeUtilities.cs new file mode 100644 index 000000000000..d2c3eef84a58 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/ContentTypeUtilities.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Text; + +namespace System.ClientModel.Internal; + +internal class ContentTypeUtilities +{ + public static bool TryGetTextEncoding(string contentType, out Encoding? encoding) + { + const string charsetMarker = "; charset="; + const string utf8Charset = "utf-8"; + const string textContentTypePrefix = "text/"; + const string jsonSuffix = "json"; + const string appJsonPrefix = "application/json"; + const string xmlSuffix = "xml"; + const string urlEncodedSuffix = "-urlencoded"; + + // Default is technically US-ASCII, but will default to UTF-8 which is a superset. + const string appFormUrlEncoded = "application/x-www-form-urlencoded"; + + if (contentType == null) + { + encoding = null; + return false; + } + + var charsetIndex = contentType.IndexOf(charsetMarker, StringComparison.OrdinalIgnoreCase); + if (charsetIndex != -1) + { + ReadOnlySpan charset = contentType.AsSpan().Slice(charsetIndex + charsetMarker.Length); + if (charset.StartsWith(utf8Charset.AsSpan(), StringComparison.OrdinalIgnoreCase)) + { + encoding = Encoding.UTF8; + return true; + } + } + + if (contentType.StartsWith(textContentTypePrefix, StringComparison.OrdinalIgnoreCase) || + contentType.EndsWith(jsonSuffix, StringComparison.OrdinalIgnoreCase) || + contentType.EndsWith(xmlSuffix, StringComparison.OrdinalIgnoreCase) || + contentType.EndsWith(urlEncodedSuffix, StringComparison.OrdinalIgnoreCase) || + contentType.StartsWith(appJsonPrefix, StringComparison.OrdinalIgnoreCase) || + contentType.StartsWith(appFormUrlEncoded, StringComparison.OrdinalIgnoreCase)) + { + encoding = Encoding.UTF8; + return true; + } + + encoding = null; + return false; + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/ClientModelEventSource.cs b/sdk/core/System.ClientModel/src/Internal/Logging/ClientModelEventSource.cs new file mode 100644 index 000000000000..f84d5714f72b --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/ClientModelEventSource.cs @@ -0,0 +1,285 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Tracing; +using System.Text; + +namespace System.ClientModel.Internal; + +// The methods in this class should only ever be called from PipelineMessageLogger, PipelineRetryLogger, or PipelineTransportLogger +[EventSource(Name = "System.ClientModel")] +internal sealed class ClientModelEventSource : EventSource +{ + private ClientModelEventSource(string eventSourceName, string[]? traits = default) : base(eventSourceName, EventSourceSettings.Default, traits) { } + + public static ClientModelEventSource Log = new("System.ClientModel"); + + #region Request + + [NonEvent] + public void Request(string requestId, PipelineRequest request, string? clientAssembly, PipelineMessageSanitizer sanitizer) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + Request(requestId, request.Method, sanitizer.SanitizeUrl(request.Uri!.AbsoluteUri), FormatHeaders(request.Headers, sanitizer), clientAssembly); + } + } + + [Event(LoggingEventIds.RequestEvent, Level = EventLevel.Informational, Message = "Request [{0}] {1} {2}\r\n{3}client assembly: {4}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + private void Request(string? requestId, string method, string uri, string headers, string? clientAssembly) + { + WriteEvent(LoggingEventIds.RequestEvent, requestId, method, uri, headers, clientAssembly); + } + + [NonEvent] + public void RequestContent(string requestId, byte[] content, Encoding? textEncoding) + { + if (IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + if (textEncoding != null) + { + RequestContentText(requestId, textEncoding.GetString(content)); + } + else + { + RequestContent(requestId, content); + } + } + } + + [Event(LoggingEventIds.RequestContentEvent, Level = EventLevel.Verbose, Message = "Request [{0}] content: {1}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with an array with primitive type elements.")] + private void RequestContent(string? requestId, byte[] content) + { + WriteEvent(LoggingEventIds.RequestContentEvent, requestId, content); + } + + [Event(LoggingEventIds.RequestContentTextEvent, Level = EventLevel.Verbose, Message = "Request [{0}] content: {1}")] + private void RequestContentText(string? requestId, string content) + { + WriteEvent(LoggingEventIds.RequestContentTextEvent, requestId, content); + } + + #endregion + + #region Response + + [NonEvent] + public void Response(string requestId, PipelineResponse response, double seconds, PipelineMessageSanitizer sanitizer) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + Response(requestId, response.Status, response.ReasonPhrase, FormatHeaders(response.Headers, sanitizer), seconds); + } + } + + [Event(LoggingEventIds.ResponseEvent, Level = EventLevel.Informational, Message = "Response [{0}] {1} {2} ({4:00.0}s)\r\n{3}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + private void Response(string? requestId, int status, string reasonPhrase, string headers, double seconds) + { + WriteEvent(LoggingEventIds.ResponseEvent, requestId, status, reasonPhrase, headers, seconds); + } + + [NonEvent] + public void ResponseContent(string requestId, byte[] content, Encoding? textEncoding) + { + if (IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + if (textEncoding is not null) + { + ResponseContentText(requestId, textEncoding.GetString(content)); + } + else + { + ResponseContent(requestId, content); + } + } + } + + [Event(LoggingEventIds.ResponseContentEvent, Level = EventLevel.Verbose, Message = "Response [{0}] content: {1}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with an array with primitive type elements.")] + private void ResponseContent(string? requestId, byte[] content) + { + WriteEvent(LoggingEventIds.ResponseContentEvent, requestId, content); + } + + [Event(LoggingEventIds.ResponseContentTextEvent, Level = EventLevel.Verbose, Message = "Response [{0}] content: {1}")] + private void ResponseContentText(string? requestId, string content) + { + WriteEvent(LoggingEventIds.ResponseContentTextEvent, requestId, content); + } + + [NonEvent] + public void ResponseContentBlock(string requestId, int blockNumber, byte[] content, Encoding? textEncoding) + { + if (IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + if (textEncoding is not null) + { + ResponseContentTextBlock(requestId, blockNumber, textEncoding.GetString(content)); + } + else + { + ResponseContentBlock(requestId, blockNumber, content); + } + } + } + + [Event(LoggingEventIds.ResponseContentBlockEvent, Level = EventLevel.Verbose, Message = "Response [{0}] content block {1}: {2}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with an array with primitive type elements.")] + private void ResponseContentBlock(string? requestId, int blockNumber, byte[] content) + { + WriteEvent(LoggingEventIds.ResponseContentBlockEvent, requestId, blockNumber, content); + } + + [Event(LoggingEventIds.ResponseContentTextBlockEvent, Level = EventLevel.Verbose, Message = "Response [{0}] content block {1}: {2}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + private void ResponseContentTextBlock(string? requestId, int blockNumber, string content) + { + WriteEvent(LoggingEventIds.ResponseContentTextBlockEvent, requestId, blockNumber, content); + } + + #endregion + + #region Error Response + + [NonEvent] + public void ErrorResponse(string requestId, PipelineResponse response, double elapsed, PipelineMessageSanitizer sanitizer) + { + if (IsEnabled(EventLevel.Warning, EventKeywords.None)) + { + ErrorResponse(requestId, response.Status, response.ReasonPhrase, FormatHeaders(response.Headers, sanitizer), elapsed); + } + } + + [Event(LoggingEventIds.ErrorResponseEvent, Level = EventLevel.Warning, Message = "Error response [{0}] {1} {2} ({4:00.0}s)\r\n{3}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + private void ErrorResponse(string? requestId, int status, string reasonPhrase, string headers, double seconds) + { + WriteEvent(LoggingEventIds.ErrorResponseEvent, requestId, status, reasonPhrase, headers, seconds); + } + + [NonEvent] + public void ErrorResponseContent(string requestId, byte[] content, Encoding? textEncoding) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + if (textEncoding is not null) + { + ErrorResponseContentText(requestId, textEncoding.GetString(content)); + } + else + { + ErrorResponseContent(requestId, content); + } + } + } + + [Event(LoggingEventIds.ErrorResponseContentEvent, Level = EventLevel.Informational, Message = "Error response [{0}] content: {1}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with an array with primitive type elements.")] + private void ErrorResponseContent(string? requestId, byte[] content) + { + WriteEvent(LoggingEventIds.ErrorResponseContentEvent, requestId, content); + } + + [Event(LoggingEventIds.ErrorResponseContentTextEvent, Level = EventLevel.Informational, Message = "Error response [{0}] content: {1}")] + private void ErrorResponseContentText(string? requestId, string content) + { + WriteEvent(LoggingEventIds.ErrorResponseContentTextEvent, requestId, content); + } + + [NonEvent] + public void ErrorResponseContentBlock(string requestId, int blockNumber, byte[] content, Encoding? textEncoding) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + if (textEncoding is not null) + { + ErrorResponseContentTextBlock(requestId, blockNumber, textEncoding.GetString(content)); + } + else + { + ErrorResponseContentBlock(requestId, blockNumber, content); + } + } + } + + [Event(LoggingEventIds.ErrorResponseContentBlockEvent, Level = EventLevel.Informational, Message = "Error response [{0}] content block {1}: {2}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with an array with primitive type elements.")] + private void ErrorResponseContentBlock(string? requestId, int blockNumber, byte[] content) + { + WriteEvent(LoggingEventIds.ErrorResponseContentBlockEvent, requestId, blockNumber, content); + } + + [Event(LoggingEventIds.ErrorResponseContentTextBlockEvent, Level = EventLevel.Informational, Message = "Error response [{0}] content block {1}: {2}")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + private void ErrorResponseContentTextBlock(string? requestId, int blockNumber, string content) + { + WriteEvent(LoggingEventIds.ErrorResponseContentTextBlockEvent, requestId, blockNumber, content); + } + + #endregion + + #region Retry + + [Event(LoggingEventIds.RequestRetryingEvent, Level = EventLevel.Informational, Message = "Request [{0}] attempt number {1} took {2:00.0}s")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + public void RequestRetrying(string? requestId, int retryNumber, double seconds) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + WriteEvent(LoggingEventIds.RequestRetryingEvent, requestId, retryNumber, seconds); + } + } + + #endregion + + #region Response Delay + + [Event(LoggingEventIds.ResponseDelayEvent, Level = EventLevel.Warning, Message = "Response [{0}] took {1:00.0}s")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "WriteEvent is used with primitive types.")] + public void ResponseDelay(string? requestId, double seconds) + { + if (IsEnabled(EventLevel.Warning, EventKeywords.None)) + { + WriteEvent(LoggingEventIds.ResponseDelayEvent, requestId, seconds); + } + } + + #endregion + + #region Exception Response + + [Event(LoggingEventIds.ExceptionResponseEvent, Level = EventLevel.Informational, Message = "Request [{0}] exception {1}")] + public void ExceptionResponse(string? requestId, string exception) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + WriteEvent(LoggingEventIds.ExceptionResponseEvent, requestId, exception); + } + } + + #endregion + + #region Helpers + + [NonEvent] + private string FormatHeaders(IEnumerable> headers, PipelineMessageSanitizer sanitizer) + { + var stringBuilder = new StringBuilder(); + foreach (var header in headers) + { + stringBuilder.Append(header.Key); + stringBuilder.Append(':'); + stringBuilder.Append(sanitizer.SanitizeHeader(header.Key, header.Value)); + stringBuilder.Append(Environment.NewLine); + } + return stringBuilder.ToString(); + } + + #endregion +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/LoggingEventIds.cs b/sdk/core/System.ClientModel/src/Internal/Logging/LoggingEventIds.cs new file mode 100644 index 000000000000..f121ba4af494 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/LoggingEventIds.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace System.ClientModel.Internal; + +internal class LoggingEventIds +{ + public const int RequestEvent = 1; + public const int RequestContentEvent = 2; + public const int ResponseEvent = 5; + public const int ResponseContentEvent = 6; + public const int ResponseDelayEvent = 7; + public const int ErrorResponseEvent = 8; + public const int ErrorResponseContentEvent = 9; + public const int RequestRetryingEvent = 10; + public const int ResponseContentBlockEvent = 11; + public const int ErrorResponseContentBlockEvent = 12; + public const int ResponseContentTextEvent = 13; + public const int ErrorResponseContentTextEvent = 14; + public const int ResponseContentTextBlockEvent = 15; + public const int ErrorResponseContentTextBlockEvent = 16; + public const int RequestContentTextEvent = 17; + public const int ExceptionResponseEvent = 18; +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageHeadersLogValue.cs b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageHeadersLogValue.cs new file mode 100644 index 000000000000..8f7df0ece01c --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageHeadersLogValue.cs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Text; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace System.ClientModel.Internal; + +internal class PipelineMessageHeadersLogValue : IReadOnlyList> +{ + private readonly PipelineRequestHeaders? _requestHeaders; + private readonly PipelineResponseHeaders? _responseHeaders; + private readonly PipelineMessageSanitizer _sanitizer; + + private List>? _values; + private string? _formatted; + + public PipelineMessageHeadersLogValue(PipelineRequestHeaders headers, PipelineMessageSanitizer sanitizer) + { + _sanitizer = sanitizer; + _requestHeaders = headers; + } + + public PipelineMessageHeadersLogValue(PipelineResponseHeaders headers, PipelineMessageSanitizer sanitizer) + { + _sanitizer = sanitizer; + _responseHeaders = headers; + } + + private List> Values + { + get + { + if (_values == null) + { + var values = new List>(); + + if (_requestHeaders != null) + { + foreach (KeyValuePair kvp in _requestHeaders) + { + values.Add(new KeyValuePair(kvp.Key, _sanitizer.SanitizeHeader(kvp.Key, kvp.Value))); + } + } + else if (_responseHeaders != null) + { + foreach (KeyValuePair kvp in _responseHeaders) + { + values.Add(new KeyValuePair(kvp.Key, _sanitizer.SanitizeHeader(kvp.Key, kvp.Value))); + } + } + + _values = values; + } + + return _values; + } + } + + public KeyValuePair this[int index] + { + get + { + if (index < 0 || index >= Count) + { + throw new IndexOutOfRangeException(nameof(index)); + } + + return Values[index]; + } + } + + public int Count => Values.Count; + + public IEnumerator> GetEnumerator() + { + return Values.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return Values.GetEnumerator(); + } + + public override string ToString() + { + if (_formatted == null) + { + var builder = new StringBuilder(); + + foreach (KeyValuePair header in Values) + { + builder.Append(header.Key); + builder.Append(':'); + builder.Append(_sanitizer.SanitizeHeader(header.Key, header.Value)); + builder.Append(Environment.NewLine); + } + + _formatted = builder.ToString(); + } + + return _formatted; + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageLogger.cs b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageLogger.cs new file mode 100644 index 000000000000..da86c233b907 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageLogger.cs @@ -0,0 +1,241 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Text; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace System.ClientModel.Internal; + +internal partial class PipelineMessageLogger +{ + private readonly ILogger? _logger; + private readonly PipelineMessageSanitizer _sanitizer; + + public PipelineMessageLogger(PipelineMessageSanitizer sanitizer, ILoggerFactory? loggerFactory) + { + _sanitizer = sanitizer; + _logger = loggerFactory?.CreateLogger() ?? null; + } + + /// + /// Whether the given log level or event level is enabled, depending + /// on whether this handler logs to ILogger or Event Source. Should be + /// used to guard expensive operations. + /// + /// The LogLevel to log to for ILogger. If an ILogger WAS NOT provided to the constructor, this value will be ignored. + /// The EventLevel to log to for EventSource. If an ILogger WAS provided to the constructor, this value will be ignored. + public bool IsEnabled(LogLevel logLevel, EventLevel eventLevel) + { + return _logger is not null ? _logger.IsEnabled(logLevel) : ClientModelEventSource.Log.IsEnabled(eventLevel, EventKeywords.None); + } + + #region Request + + public void LogRequest(string requestId, PipelineRequest request, string? clientAssembly) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Information)) + { + Request(_logger, requestId, request.Method, _sanitizer.SanitizeUrl(request.Uri!.AbsoluteUri), new PipelineMessageHeadersLogValue(request.Headers, _sanitizer), clientAssembly); + } + } + else + { + ClientModelEventSource.Log.Request(requestId, request, clientAssembly, _sanitizer); + } + } + + [LoggerMessage(LoggingEventIds.RequestEvent, LogLevel.Information, "Request [{requestId}] {method} {uri}\r\n{headers}client assembly: {clientAssembly}", SkipEnabledCheck = true, EventName = "Request")] + private static partial void Request(ILogger logger, string requestId, string method, string uri, PipelineMessageHeadersLogValue headers, string? clientAssembly); + + public void LogRequestContent(string requestId, byte[] content, Encoding? textEncoding) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (textEncoding != null) + { + RequestContentText(_logger, requestId, textEncoding.GetString(content)); + } + else + { + RequestContent(_logger, requestId, content); + } + } + } + else + { + ClientModelEventSource.Log.RequestContent(requestId, content, textEncoding); + } + } + + [LoggerMessage(LoggingEventIds.RequestContentEvent, LogLevel.Debug, "Request [{requestId}] content: {content}", SkipEnabledCheck = true, EventName = "RequestContent")] + private static partial void RequestContent(ILogger logger, string requestId, byte[] content); + + [LoggerMessage(LoggingEventIds.RequestContentTextEvent, LogLevel.Debug, "Request [{requestId}] content: {content}", SkipEnabledCheck = true, EventName = "RequestContentText")] + private static partial void RequestContentText(ILogger logger, string requestId, string content); + + #endregion + + #region Response + + public void LogResponse(string requestId, PipelineResponse response, double seconds) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Information)) + { + Response(_logger, requestId, response.Status, response.ReasonPhrase, new PipelineMessageHeadersLogValue(response.Headers, _sanitizer), seconds); + } + } + else + { + ClientModelEventSource.Log.Response(requestId, response, seconds, _sanitizer); + } + } + + [LoggerMessage(LoggingEventIds.ResponseEvent, LogLevel.Information, "Response [{requestId}] {status} {reasonPhrase} ({seconds:00.0}s)\r\n{headers}", SkipEnabledCheck = true, EventName = "Response")] + private static partial void Response(ILogger logger, string requestId, int status, string reasonPhrase, PipelineMessageHeadersLogValue headers, double seconds); + + public void LogResponseContent(string requestId, byte[] content, Encoding? textEncoding) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (textEncoding != null) + { + ResponseContentText(_logger, requestId, textEncoding.GetString(content)); + } + else + { + ResponseContent(_logger, requestId, content); + } + } + } + else + { + ClientModelEventSource.Log.ResponseContent(requestId, content, textEncoding); + } + } + + [LoggerMessage(LoggingEventIds.ResponseContentEvent, LogLevel.Debug, "Response [{requestId}] content: {content}", SkipEnabledCheck = true, EventName = "ResponseContent")] + private static partial void ResponseContent(ILogger logger, string requestId, byte[] content); + + [LoggerMessage(LoggingEventIds.ResponseContentTextEvent, LogLevel.Debug, "Response [{requestId}] content: {content}", SkipEnabledCheck = true, EventName = "ResponseContentText")] + private static partial void ResponseContentText(ILogger logger, string requestId, string content); + + public void LogResponseContentBlock(string requestId, int blockNumber, byte[] content, Encoding? textEncoding) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (textEncoding != null) + { + ResponseContentTextBlock(_logger, requestId, blockNumber, textEncoding.GetString(content)); + } + else + { + ResponseContentBlock(_logger, requestId, blockNumber, content); + } + } + } + else + { + ClientModelEventSource.Log.ResponseContentBlock(requestId, blockNumber, content, textEncoding); + } + } + + [LoggerMessage(LoggingEventIds.ResponseContentBlockEvent, LogLevel.Debug, "Response [{requestId}] content block {blockNumber}: {content}", SkipEnabledCheck = true, EventName = "ResponseContentBlock")] + private static partial void ResponseContentBlock(ILogger logger, string requestId, int blockNumber, byte[] content); + + [LoggerMessage(LoggingEventIds.ResponseContentTextBlockEvent, LogLevel.Debug, "Response [{requestId}] content block {blockNumber}: {content}", SkipEnabledCheck = true, EventName = "ResponseContentTextBlock")] + private static partial void ResponseContentTextBlock(ILogger logger, string requestId, int blockNumber, string content); + + #endregion + + #region Error Response + + public void LogErrorResponse(string requestId, PipelineResponse response, double seconds) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Warning)) + { + ErrorResponse(_logger, requestId, response.Status, response.ReasonPhrase, new PipelineMessageHeadersLogValue(response.Headers, _sanitizer), seconds); + } + } + else + { + ClientModelEventSource.Log.ErrorResponse(requestId, response, seconds, _sanitizer); + } + } + + [LoggerMessage(LoggingEventIds.ErrorResponseEvent, LogLevel.Warning, "Error response [{requestId}] {status} {reasonPhrase} ({seconds:00.0}s)\r\n{headers}", SkipEnabledCheck = true, EventName = "ErrorResponse")] + private static partial void ErrorResponse(ILogger logger, string requestId, int status, string reasonPhrase, PipelineMessageHeadersLogValue headers, double seconds); + + public void LogErrorResponseContent(string requestId, byte[] content, Encoding? textEncoding) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Information)) + { + if (textEncoding != null) + { + ErrorResponseContentText(_logger, requestId, textEncoding.GetString(content)); + } + else + { + ErrorResponseContent(_logger, requestId, content); + } + } + } + else + { + ClientModelEventSource.Log.ErrorResponseContent(requestId, content, textEncoding); + } + } + + [LoggerMessage(LoggingEventIds.ErrorResponseContentEvent, LogLevel.Information, "Error response [{requestId}] content: {content}", SkipEnabledCheck = true, EventName = "ErrorResponseContent")] + private static partial void ErrorResponseContent(ILogger logger, string requestId, byte[] content); + + [LoggerMessage(LoggingEventIds.ErrorResponseContentTextEvent, LogLevel.Information, "Error response [{requestId}] content: {content}", SkipEnabledCheck = true, EventName = "ErrorResponseContentText")] + private static partial void ErrorResponseContentText(ILogger logger, string requestId, string content); + + public void LogErrorResponseContentBlock(string requestId, int blockNumber, byte[] content, Encoding? textEncoding) + { + if (_logger is not null) + { + if (_logger.IsEnabled(LogLevel.Information)) + { + if (textEncoding != null) + { + ErrorResponseContentTextBlock(_logger, requestId, blockNumber, textEncoding.GetString(content)); + } + else + { + ErrorResponseContentBlock(_logger, requestId, blockNumber, content); + } + } + } + else + { + ClientModelEventSource.Log.ErrorResponseContentBlock(requestId, blockNumber, content, textEncoding); + } + } + + [LoggerMessage(LoggingEventIds.ErrorResponseContentBlockEvent, LogLevel.Information, "Error response [{requestId}] content block {blockNumber}: {content}", SkipEnabledCheck = true, EventName = "ErrorResponseContentBlock")] + private static partial void ErrorResponseContentBlock(ILogger logger, string requestId, int blockNumber, byte[] content); + + [LoggerMessage(LoggingEventIds.ErrorResponseContentTextBlockEvent, LogLevel.Information, "Error response [{requestId}] content block {blockNumber}: {content}", SkipEnabledCheck = true, EventName = "ErrorResponseContentTextBlock")] + private static partial void ErrorResponseContentTextBlock(ILogger logger, string requestId, int blockNumber, string content); + + #endregion +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageSanitizer.cs b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageSanitizer.cs new file mode 100644 index 000000000000..606c7e670de9 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineMessageSanitizer.cs @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace System.ClientModel.Internal; + +internal class PipelineMessageSanitizer +{ + private const string LogAllValue = "*"; + private readonly bool _logAllHeaders; + private readonly bool _logFullQueries; + private readonly string _redactedPlaceholder; + + [ThreadStatic] + private static StringBuilder? s_cachedStringBuilder; + private const int MaxCachedStringBuilderCapacity = 1024; + + internal readonly HashSet _allowedQueryParameters; //internal for testing + internal readonly HashSet _allowedHeaders; + + public PipelineMessageSanitizer(HashSet allowedQueryParameters, HashSet allowedHeaders, string redactedPlaceholder = "REDACTED") + { + _logAllHeaders = allowedHeaders.Contains(LogAllValue); + _logFullQueries = allowedQueryParameters.Contains(LogAllValue); + + _allowedQueryParameters = allowedQueryParameters; + _redactedPlaceholder = redactedPlaceholder; + _allowedHeaders = allowedHeaders; + } + + public string SanitizeHeader(string name, string value) + { + if (_logAllHeaders || _allowedHeaders.Contains(name)) + { + return value; + } + + return _redactedPlaceholder; + } + + public bool ShouldSanitizeHeaderValue(string name) + { + if (_logAllHeaders || _allowedHeaders.Contains(name)) + { + return false; + } + + return true; + } + + public string SanitizeUrl(string url) + { + if (_logFullQueries) + { + return url; + } + +#if NET6_0_OR_GREATER + int indexOfQuerySeparator = url.IndexOf('?', StringComparison.Ordinal); +#else + int indexOfQuerySeparator = url.IndexOf('?'); +#endif + + if (indexOfQuerySeparator == -1) + { + return url; + } + + // PERF: Avoid allocations in this heavily-used method: + // 1. Use ReadOnlySpan to avoid creating substrings. + // 2. Defer creating a StringBuilder until absolutely necessary. + // 3. Use a rented StringBuilder to avoid allocating a new one + // each time. + + // Create the StringBuilder only when necessary (when we encounter + // a query parameter that needs to be redacted) + StringBuilder? stringBuilder = null; + + // Keeps track of the number of characters we've processed so far + // so that, if we need to create a StringBuilder, we know how many + // characters to copy over from the original URL. + int lengthSoFar = indexOfQuerySeparator + 1; + + ReadOnlySpan query = url.AsSpan(indexOfQuerySeparator + 1); // +1 to skip the '?' + + while (query.Length > 0) + { + int endOfParameterValue = query.IndexOf('&'); + int endOfParameterName = query.IndexOf('='); + bool noValue = false; + + // Check if we have parameter without value + if ((endOfParameterValue == -1 && endOfParameterName == -1) || + (endOfParameterValue != -1 && (endOfParameterName == -1 || endOfParameterName > endOfParameterValue))) + { + endOfParameterName = endOfParameterValue; + noValue = true; + } + + if (endOfParameterName == -1) + { + endOfParameterName = query.Length; + } + + if (endOfParameterValue == -1) + { + endOfParameterValue = query.Length; + } + else + { + // include the separator + endOfParameterValue++; + } + + ReadOnlySpan parameterName = query.Slice(0, endOfParameterName); + + bool isAllowed = false; + foreach (string name in _allowedQueryParameters) + { + if (parameterName.Equals(name.AsSpan(), StringComparison.OrdinalIgnoreCase)) + { + isAllowed = true; + break; + } + } + + int valueLength = endOfParameterValue; + int nameLength = endOfParameterName; + + if (isAllowed || noValue) + { + if (stringBuilder is null) + { + lengthSoFar += valueLength; + } + else + { + AppendReadOnlySpan(stringBuilder, query.Slice(0, valueLength)); + } + } + else + { + // Encountered a query value that needs to be redacted. + // Create the StringBuilder if we haven't already. + stringBuilder ??= RentStringBuilder(url.Length).Append(url, 0, lengthSoFar); + + AppendReadOnlySpan(stringBuilder, query.Slice(0, nameLength)) + .Append('=') + .Append(_redactedPlaceholder); + + if (query[endOfParameterValue - 1] == '&') + { + stringBuilder.Append('&'); + } + } + + query = query.Slice(valueLength); + } + + return stringBuilder is null ? url : ToStringAndReturnStringBuilder(stringBuilder); + + static StringBuilder AppendReadOnlySpan(StringBuilder builder, ReadOnlySpan span) + { +#if NET6_0_OR_GREATER + return builder.Append(span); +#else + foreach (char c in span) + { + builder.Append(c); + } + + return builder; +#endif + } + } + + private static StringBuilder RentStringBuilder(int capacity) + { + if (capacity <= MaxCachedStringBuilderCapacity) + { + StringBuilder? builder = s_cachedStringBuilder; + if (builder is not null && builder.Capacity >= capacity) + { + s_cachedStringBuilder = null; + return builder; + } + } + + return new StringBuilder(capacity); + } + + private static string ToStringAndReturnStringBuilder(StringBuilder builder) + { + string result = builder.ToString(); + if (builder.Capacity <= MaxCachedStringBuilderCapacity) + { + s_cachedStringBuilder = builder.Clear(); + } + + return result; + } +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/PipelineRetryLogger.cs b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineRetryLogger.cs new file mode 100644 index 000000000000..86e1b6160e64 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineRetryLogger.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using Microsoft.Extensions.Logging; + +namespace System.ClientModel.Internal; + +internal partial class PipelineRetryLogger +{ + private readonly ILogger? _logger; + + public PipelineRetryLogger(ILoggerFactory? loggerFactory) + { + _logger = loggerFactory?.CreateLogger() ?? null; + } + + public void LogRequestRetrying(string? requestId, int retryNumber, double seconds) + { + if (_logger is not null) + { + RequestRetrying(_logger, requestId, retryNumber, seconds); + } + else + { + ClientModelEventSource.Log.RequestRetrying(requestId, retryNumber, seconds); + } + } + + [LoggerMessage(LoggingEventIds.RequestRetryingEvent, LogLevel.Information, "Request [{requestId}] attempt number {retryNumber} took {seconds:00.0}s", EventName = "RequestRetrying")] + private static partial void RequestRetrying(ILogger logger, string? requestId, int retryNumber, double seconds); +} diff --git a/sdk/core/System.ClientModel/src/Internal/Logging/PipelineTransportLogger.cs b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineTransportLogger.cs new file mode 100644 index 000000000000..dea86edb50be --- /dev/null +++ b/sdk/core/System.ClientModel/src/Internal/Logging/PipelineTransportLogger.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Diagnostics.Tracing; +using Microsoft.Extensions.Logging; + +namespace System.ClientModel.Internal; + +internal partial class PipelineTransportLogger +{ + private readonly ILogger? _logger; + + public PipelineTransportLogger(ILoggerFactory? loggerFactory) + { + _logger = loggerFactory?.CreateLogger() ?? null; + } + + #region Delay Response + + public void LogResponseDelay(string requestId, double seconds) + { + if (_logger is not null) + { + ResponseDelay(_logger, requestId, seconds); + } + else + { + ClientModelEventSource.Log.ResponseDelay(requestId, seconds); + } + } + + [LoggerMessage(LoggingEventIds.ResponseDelayEvent, LogLevel.Warning, "Response [{requestId}] took {seconds:00.0}s", EventName = "ResponseDelay")] + private static partial void ResponseDelay(ILogger logger, string requestId, double seconds); + + #endregion + + #region Exception Response + + public void LogExceptionResponse(string requestId, Exception exception) + { + if (_logger is not null) + { + ExceptionResponse(_logger, requestId, exception); + } + else if (ClientModelEventSource.Log.IsEnabled(EventLevel.Informational, EventKeywords.None)) + { + ClientModelEventSource.Log.ExceptionResponse(requestId, exception.ToString()); + } + } + + [LoggerMessage(LoggingEventIds.ExceptionResponseEvent, LogLevel.Information, "Request [{requestId}] exception occurred.", EventName = "ExceptionResponse")] + private static partial void ExceptionResponse(ILogger logger, string requestId, Exception exception); + + #endregion +} diff --git a/sdk/core/System.ClientModel/src/Message/PipelineRequest.cs b/sdk/core/System.ClientModel/src/Message/PipelineRequest.cs index eaf15c1bfc05..771c42ec855e 100644 --- a/sdk/core/System.ClientModel/src/Message/PipelineRequest.cs +++ b/sdk/core/System.ClientModel/src/Message/PipelineRequest.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System.Diagnostics; + namespace System.ClientModel.Primitives; /// @@ -68,6 +70,11 @@ public BinaryContent? Content /// protected abstract BinaryContent? ContentCore { get; set; } + /// + /// The client request id to include in log entries. + /// + internal string? ClientRequestId { get; set; } + /// public abstract void Dispose(); } diff --git a/sdk/core/System.ClientModel/src/Options/ClientLoggingOptions.cs b/sdk/core/System.ClientModel/src/Options/ClientLoggingOptions.cs new file mode 100644 index 000000000000..171a57417eba --- /dev/null +++ b/sdk/core/System.ClientModel/src/Options/ClientLoggingOptions.cs @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace System.ClientModel.Primitives; + +/// +/// Exposes client options for logging within a . +/// +public class ClientLoggingOptions +{ + private bool _frozen; + + private bool? _enableLogging; + private bool? _enableMessageLogging; + private bool? _enableMessageContentLogging; + private int? _messageContentSizeLimit; + private ILoggerFactory? _loggerFactory; + private PipelineMessageSanitizer? _sanitizer; + private ChangeTrackingStringList? _allowedHeaderNames; + private ChangeTrackingStringList? _allowedQueryParameters; + + // These values are similar to the default values in Azure.Core.DiagnosticsOptions and both + // should be kept in sync. When updating, update the default values in both classes. + private static readonly HashSet s_defaultAllowedHeaderNames = ["traceparent", + "Accept", + "Cache-Control", + "Connection", + "Content-Length", + "Content-Type", + "Date", + "ETag", + "Expires", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Unmodified-Since", + "Last-Modified", + "Pragma", + "Retry-After", + "Server", + "Transfer-Encoding", + "User-Agent", + "WWW-Authenticate" ]; + private static readonly HashSet s_defaultAllowedQueryParameters = ["api-version"]; + private static readonly PipelineMessageSanitizer s_defaultSanitizer = new(s_defaultAllowedQueryParameters, s_defaultAllowedHeaderNames); + + internal const bool DefaultEnableLogging = true; + internal const bool DefaultEnableMessageContentLogging = false; + internal const int DefaultMessageContentSizeLimitBytes = 4 * 1024; + internal const double RequestTooLongSeconds = 3.0; // sec + + /// + /// Gets or sets the implementation of to use to + /// create instances for logging. + /// + /// If an ILoggerFactory is not provided, logs will be written to Event Source + /// instead. If an ILoggerFactory is provided, logs will be written to ILogger only and not + /// Event Source. + /// Defaults to . + public ILoggerFactory? LoggerFactory + { + get => _loggerFactory; + set + { + AssertNotFrozen(); + + _loggerFactory = value; + } + } + + /// + /// Gets or sets value indicating if logging should be enabled in this client pipeline. + /// + /// Defaults to null. If null, this value will be treated as true. + public bool? EnableLogging + { + get => _enableLogging; + set + { + AssertNotFrozen(); + + _enableLogging = value; + } + } + + /// + /// Gets or sets value indicating if request and response uri and header information should be logged. + /// + /// Defaults to null. If null, the value + /// of will be used instead. + public bool? EnableMessageLogging + { + get => _enableMessageLogging; + set + { + AssertNotFrozen(); + + _enableMessageLogging = value; + } + } + + /// + /// Gets or sets value indicating if request and response content should be logged. + /// + /// Defaults to null. If null, this value will be treated as false. + public bool? EnableMessageContentLogging + { + get => _enableMessageContentLogging; + set + { + AssertNotFrozen(); + + _enableMessageContentLogging = value; + } + } + + /// + /// Gets or sets value indicating maximum size of content to log in bytes. + /// + /// Defaults to null. If null, this value will be treated as + public int? MessageContentSizeLimit + { + get => _messageContentSizeLimit; + set + { + AssertNotFrozen(); + + _messageContentSizeLimit = value; + } + } + + /// + /// Gets or sets a list of header names that are not redacted during logging. + /// + /// Defaults to a list of common header names that do not + /// typically hold sensitive information. + public IList AllowedHeaderNames + { + get + { + if (!_frozen) + { + if (_allowedHeaderNames is null) + { + var changeList = new ChangeTrackingStringList(s_defaultAllowedHeaderNames); + _allowedHeaderNames = changeList; + } + return _allowedHeaderNames; + } + else + { + if (_allowedHeaderNames is null) + { + // If this instance is frozen still allow read-only access to the defaults by + // creating a copy of the default allowed headers and freezing it. This + // avoids copying the default array and allocating the change tracking list unless necessary. + _allowedHeaderNames = new ChangeTrackingStringList(s_defaultAllowedHeaderNames); + _allowedHeaderNames.Freeze(); + } + return _allowedHeaderNames; + } + } + } + + /// + /// Gets or sets a list of query parameter names that are not redacted during logging. + /// + /// Defaults to a list of common query parameters that do not + /// typically hold sensitive information. + public IList AllowedQueryParameters + { + get + { + if (!_frozen) + { + if (_allowedQueryParameters is null) + { + var changeList = new ChangeTrackingStringList(s_defaultAllowedQueryParameters); + _allowedQueryParameters = changeList; + } + return _allowedQueryParameters; + } + else + { + if (_allowedQueryParameters is null) + { + // If this instance is frozen still allow read-only access to the defaults by + // creating a copy of the default allowed query parameters and freezing it. This + // avoids copying the default array and allocating the change tracking list unless necessary. + _allowedQueryParameters = new ChangeTrackingStringList(s_defaultAllowedQueryParameters); + _allowedQueryParameters.Freeze(); + } + return _allowedQueryParameters; + } + } + } + + /// + /// Freeze this instance of . After this method + /// has been called, any attempt to set properties on the instance or call + /// methods that would change its state will throw . + /// + public virtual void Freeze() + { + _frozen = true; + if (_allowedHeaderNames is not null) + { + _allowedHeaderNames.Freeze(); + } + if (_allowedQueryParameters is not null) + { + _allowedQueryParameters.Freeze(); + } + } + + /// + /// Assert that has not been called on this + /// instance. + /// + /// Thrown when an attempt is + /// made to change the state of this instance + /// after has been called. + protected void AssertNotFrozen() + { + if (_frozen) + { + throw new InvalidOperationException("Cannot change a ClientLoggingOptions instance after the ClientPipeline is created."); + } + } + + internal void ValidateOptions() + { + if (EnableLogging == false + && (EnableMessageLogging == true || EnableMessageContentLogging == true)) + { + throw new InvalidOperationException("HTTP Message logging cannot be enabled when client-wide logging is disabled."); + } + if (EnableMessageLogging == false + && EnableMessageContentLogging == true) + { + throw new InvalidOperationException("HTTP Message content logging cannot be enabled when HTTP message logging is disabled."); + } + } + + internal PipelineMessageSanitizer GetPipelineMessageSanitizer() + { + Console.WriteLine($"Allowed header names:{_allowedHeaderNames}"); + if (HeaderListIsDefault && QueryParameterListIsDefault) + { + return s_defaultSanitizer; + } + HashSet headers = _allowedHeaderNames == null ? s_defaultAllowedHeaderNames : new HashSet(_allowedHeaderNames, StringComparer.InvariantCultureIgnoreCase); + HashSet queryParams = _allowedQueryParameters == null ? s_defaultAllowedQueryParameters : new HashSet(_allowedQueryParameters, StringComparer.InvariantCultureIgnoreCase); + + _sanitizer ??= new PipelineMessageSanitizer(queryParams, headers); + + return _sanitizer; + } + + internal bool AddMessageLoggingPolicy => EnableMessageLogging ?? EnableLogging ?? DefaultEnableLogging; + + internal bool UseDefaultClientWideLogging => LoggerFactory == null + && EnableLogging == null; + + internal bool AddDefaultMessageLoggingPolicy => EnableLogging == null + && MessageContentSizeLimit == null + && EnableMessageLogging == null + && EnableMessageContentLogging == null + && LoggerFactory == null + && HeaderListIsDefault + && QueryParameterListIsDefault; + + private bool HeaderListIsDefault => _allowedHeaderNames == null || !_allowedHeaderNames.HasChanged; + private bool QueryParameterListIsDefault => _allowedQueryParameters == null || !_allowedQueryParameters.HasChanged; +} diff --git a/sdk/core/System.ClientModel/src/Options/ClientPipelineOptions.cs b/sdk/core/System.ClientModel/src/Options/ClientPipelineOptions.cs index e8289a030228..6cb2f7ce475a 100644 --- a/sdk/core/System.ClientModel/src/Options/ClientPipelineOptions.cs +++ b/sdk/core/System.ClientModel/src/Options/ClientPipelineOptions.cs @@ -20,8 +20,10 @@ public class ClientPipelineOptions private bool _frozen; private PipelinePolicy? _retryPolicy; + private PipelinePolicy? _loggingPolicy; private PipelineTransport? _transport; private TimeSpan? _timeout; + private ClientLoggingOptions? _loggingOptions; #region Pipeline creation: Overrides of default pipeline policies @@ -44,6 +46,25 @@ public PipelinePolicy? RetryPolicy } } + /// + /// Gets or sets the to be used by the + /// for logging. + /// + /// + /// In most cases, this property will be set to an instance of + /// . + /// + public PipelinePolicy? MessageLoggingPolicy + { + get => _loggingPolicy; + set + { + AssertNotFrozen(); + + _loggingPolicy = value; + } + } + /// /// Gets or sets the to be used by the /// for sending and receiving HTTP messages. @@ -81,6 +102,21 @@ public TimeSpan? NetworkTimeout } } + /// + /// The options to be used to configure logging within the + /// . + /// + public ClientLoggingOptions? ClientLoggingOptions + { + get => _loggingOptions; + set + { + AssertNotFrozen(); + + _loggingOptions = value; + } + } + #endregion #region Pipeline creation: User-specified policies @@ -161,7 +197,11 @@ internal static PipelinePolicy[] AddPolicy(PipelinePolicy policy, PipelinePolicy /// instance or call methods that would change its state will throw /// . /// - public virtual void Freeze() => _frozen = true; + public virtual void Freeze() + { + _frozen = true; + _loggingOptions?.Freeze(); + } /// /// Assert that has not been called on this @@ -177,4 +217,39 @@ protected void AssertNotFrozen() throw new InvalidOperationException("Cannot change a ClientPipelineOptions instance after it has been used to create a ClientPipeline."); } } + + #region Helpers + + internal HttpClientPipelineTransport GetHttpClientPipelineTransport() + { + if (_loggingOptions == null || _loggingOptions.UseDefaultClientWideLogging) + { + return HttpClientPipelineTransport.Shared; + } + return new HttpClientPipelineTransport(null, _loggingOptions.EnableLogging ?? ClientLoggingOptions.DefaultEnableLogging, _loggingOptions.LoggerFactory); + } + + internal ClientRetryPolicy GetClientRetryPolicy() + { + if (_loggingOptions == null || _loggingOptions.UseDefaultClientWideLogging) + { + return ClientRetryPolicy.Default; + } + return new ClientRetryPolicy(ClientRetryPolicy.DefaultMaxRetries, + _loggingOptions.EnableLogging ?? ClientLoggingOptions.DefaultEnableLogging, + _loggingOptions.LoggerFactory); + } + + internal bool AddMessageLoggingPolicy => _loggingOptions?.AddMessageLoggingPolicy ?? true; + + internal MessageLoggingPolicy GetMessageLoggingPolicy() + { + if (_loggingOptions == null || _loggingOptions.AddDefaultMessageLoggingPolicy) + { + return System.ClientModel.Primitives.MessageLoggingPolicy.Default; + } + return new MessageLoggingPolicy(_loggingOptions); + } + + #endregion } diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.RequestOptionsProcessor.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.RequestOptionsProcessor.cs index 4dbeb87f3bfc..7adc63a1df9b 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.RequestOptionsProcessor.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.RequestOptionsProcessor.cs @@ -185,7 +185,7 @@ private bool TryGetCustomPerTryPolicy(int index, out PipelinePolicy policy) private bool TryGetFixedPerTransportPolicy(int index, out PipelinePolicy policy) { - if (index < _perTryIndex + _customPerCallPolicies.Length + _customPerTryPolicies.Length) + if (index < _beforeTransportIndex + _customPerCallPolicies.Length + _customPerTryPolicies.Length) { policy = _fixedPolicies.Span[index - (_customPerCallPolicies.Length + _customPerTryPolicies.Length)]; return true; @@ -197,7 +197,7 @@ private bool TryGetFixedPerTransportPolicy(int index, out PipelinePolicy policy) private bool TryGetCustomBeforeTransportPolicy(int index, out PipelinePolicy policy) { - if (index < _perTryIndex + + if (index < _beforeTransportIndex + _customPerCallPolicies.Length + _customPerTryPolicies.Length + _customBeforeTransportPolicies.Length) diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs index bae030b7ec33..4e3dc1a1840b 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientPipeline.cs @@ -5,7 +5,10 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace System.ClientModel.Primitives; @@ -26,10 +29,11 @@ public sealed partial class ClientPipeline private readonly ReadOnlyMemory _policies; private readonly PipelineTransport _transport; + private readonly bool _enableLogging; private readonly TimeSpan _networkTimeout; - private ClientPipeline(ReadOnlyMemory policies, TimeSpan networkTimeout, int perCallIndex, int perTryIndex, int beforeTransportIndex) + private ClientPipeline(ReadOnlyMemory policies, TimeSpan networkTimeout, int perCallIndex, int perTryIndex, int beforeTransportIndex, bool enableLogging) { if (policies.Span[policies.Length - 1] is not PipelineTransport) { @@ -47,6 +51,7 @@ private ClientPipeline(ReadOnlyMemory policies, TimeSpan network _beforeTransportIndex = beforeTransportIndex; _networkTimeout = networkTimeout; + _enableLogging = enableLogging; } #region Factory methods for creating a pipeline instance @@ -101,6 +106,7 @@ public static ClientPipeline Create( Argument.AssertNotNull(options, nameof(options)); options.Freeze(); + options.ClientLoggingOptions?.ValidateOptions(); // Add length of client-specific policies. int pipelineLength = perCallPolicies.Length + perTryPolicies.Length + beforeTransportPolicies.Length; @@ -111,6 +117,7 @@ public static ClientPipeline Create( pipelineLength += options.BeforeTransportPolicies?.Length ?? 0; pipelineLength++; // for retry policy + pipelineLength += options.AddMessageLoggingPolicy ? 1 : 0; // for message logging policy pipelineLength++; // for transport PipelinePolicy[] policies = new PipelinePolicy[pipelineLength]; @@ -130,7 +137,7 @@ public static ClientPipeline Create( int perCallIndex = index; // Add retry policy. - policies[index++] = options.RetryPolicy ?? ClientRetryPolicy.Default; + policies[index++] = options.RetryPolicy ?? options.GetClientRetryPolicy(); // Per try policies come after the retry policy. perTryPolicies.CopyTo(policies.AsSpan(index)); @@ -144,6 +151,13 @@ public static ClientPipeline Create( int perTryIndex = index; + // Add logging policy just before the transport. + + if (options.AddMessageLoggingPolicy) + { + policies[index++] = options.MessageLoggingPolicy ?? options.GetMessageLoggingPolicy(); + } + // Before transport policies come before the transport. beforeTransportPolicies.CopyTo(policies.AsSpan(index)); index += beforeTransportPolicies.Length; @@ -157,11 +171,13 @@ public static ClientPipeline Create( int beforeTransportIndex = index; // Add the transport. - policies[index++] = options.Transport ?? HttpClientPipelineTransport.Shared; + policies[index++] = options.Transport ?? options.GetHttpClientPipelineTransport(); + + bool enableLogging = options.ClientLoggingOptions?.EnableLogging ?? ClientLoggingOptions.DefaultEnableLogging; return new ClientPipeline(policies, options.NetworkTimeout ?? DefaultNetworkTimeout, - perCallIndex, perTryIndex, beforeTransportIndex); + perCallIndex, perTryIndex, beforeTransportIndex, enableLogging); } #endregion @@ -194,8 +210,10 @@ public PipelineMessage CreateMessage() public void Send(PipelineMessage message) { Argument.AssertNotNull(message, nameof(message)); + message.Request.ClientRequestId = Activity.Current?.Id ?? Guid.NewGuid().ToString(); IReadOnlyList policies = GetProcessor(message); + policies[0].Process(message, policies, 0); } @@ -215,8 +233,10 @@ public void Send(PipelineMessage message) public async ValueTask SendAsync(PipelineMessage message) { Argument.AssertNotNull(message, nameof(message)); + message.Request.ClientRequestId = Activity.Current?.Id ?? Guid.NewGuid().ToString(); IReadOnlyList policies = GetProcessor(message); + await policies[0].ProcessAsync(message, policies, 0).ConfigureAwait(false); } diff --git a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs index 3e9be2399624..f8e61e9a46d4 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/ClientRetryPolicy.cs @@ -7,6 +7,8 @@ using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace System.ClientModel.Primitives; @@ -22,20 +24,34 @@ public class ClientRetryPolicy : PipelinePolicy /// public static ClientRetryPolicy Default { get; } = new ClientRetryPolicy(); - private const int DefaultMaxRetries = 3; private static readonly TimeSpan DefaultInitialDelay = TimeSpan.FromSeconds(0.8); private readonly int _maxRetries; private readonly TimeSpan _initialDelay; + private readonly PipelineRetryLogger? _retryLogger; + + internal const int DefaultMaxRetries = 3; + + /// + /// Creates a new instance of the class. + /// + /// The maximum number of retries to attempt. + public ClientRetryPolicy(int maxRetries = DefaultMaxRetries) : this(maxRetries, ClientLoggingOptions.DefaultEnableLogging, default) + { + } /// /// Creates a new instance of the class. /// /// The maximum number of retries to attempt. - public ClientRetryPolicy(int maxRetries = DefaultMaxRetries) + /// If client-wide logging is enabled for this pipeline. + /// The to use to create an instance for logging. + /// If one is not provided, logs are written to Event Source by default. + public ClientRetryPolicy(int maxRetries, bool enableLogging, ILoggerFactory? loggerFactory) { _maxRetries = maxRetries; _initialDelay = DefaultInitialDelay; + _retryLogger = enableLogging ? new PipelineRetryLogger(loggerFactory) : null; } /// @@ -53,6 +69,7 @@ private async ValueTask ProcessSyncOrAsync(PipelineMessage message, IReadOnlyLis while (true) { Exception? thisTryException = null; + var before = Stopwatch.GetTimestamp(); if (async) { @@ -91,6 +108,9 @@ private async ValueTask ProcessSyncOrAsync(PipelineMessage message, IReadOnlyLis OnRequestSent(message); } + var after = Stopwatch.GetTimestamp(); + double elapsed = (after-before) / (double)Stopwatch.Frequency; + bool shouldRetry = async ? await ShouldRetryInternalAsync(message, thisTryException).ConfigureAwait(false) : ShouldRetryInternal(message, thisTryException); @@ -116,6 +136,8 @@ await ShouldRetryInternalAsync(message, thisTryException).ConfigureAwait(false) message.RetryCount++; OnTryComplete(message); + _retryLogger?.LogRequestRetrying(message.Request.ClientRequestId ?? string.Empty, message.RetryCount, elapsed); + continue; } diff --git a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs index 015847915920..bc7bb5378cbf 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/HttpClientPipelineTransport.cs @@ -6,6 +6,7 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; namespace System.ClientModel.Primitives; @@ -40,11 +41,24 @@ public HttpClientPipelineTransport() : this(_sharedDefaultClient) /// The that this transport /// instance will use to send and receive HTTP requests and responses. /// - public HttpClientPipelineTransport(HttpClient client) + public HttpClientPipelineTransport(HttpClient client) : this(client, ClientLoggingOptions.DefaultEnableLogging, null) { Argument.AssertNotNull(client, nameof(client)); + } - _httpClient = client; + /// + /// Create a new instance of that + /// uses the provided . + /// + /// The that this transport + /// instance will use to send and receive HTTP requests and responses. If no + /// is passed, a default shared client will be used. + /// + /// If client-wide logging is enabled for this pipeline. + /// The to use to create an instance for logging. + public HttpClientPipelineTransport(HttpClient? client, bool enableLogging, ILoggerFactory? loggerFactory) : base(enableLogging, loggerFactory) + { + _httpClient = client ?? _sharedDefaultClient; } private static HttpClient CreateDefaultClient() diff --git a/sdk/core/System.ClientModel/src/Pipeline/MessageLoggingPolicy.cs b/sdk/core/System.ClientModel/src/Pipeline/MessageLoggingPolicy.cs new file mode 100644 index 000000000000..1f419a765ae5 --- /dev/null +++ b/sdk/core/System.ClientModel/src/Pipeline/MessageLoggingPolicy.cs @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Tracing; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace System.ClientModel.Primitives; + +/// +/// A used by a to +/// log request and response information. +/// +public class MessageLoggingPolicy : PipelinePolicy +{ + /// + /// The instance used by a default + /// . + /// + public static MessageLoggingPolicy Default { get; } = new MessageLoggingPolicy(); + + private readonly ClientLoggingOptions _loggingOptions; + private PipelineMessageLogger? _messageLogger; + private readonly string _clientAssembly = typeof(MessageLoggingPolicy).Assembly.GetName().Name!; + + private bool _enableMessageContentLogging => _loggingOptions.EnableMessageContentLogging ?? ClientLoggingOptions.DefaultEnableMessageContentLogging; + private int _maxLength => _loggingOptions.MessageContentSizeLimit ?? ClientLoggingOptions.DefaultMessageContentSizeLimitBytes; + + /// + /// Creates a new instance of the class. + /// + /// The user-provided logging options object. + public MessageLoggingPolicy(ClientLoggingOptions? options = default) + { + _loggingOptions = options ?? new(); + } + + /// + public sealed override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) => + ProcessSyncOrAsync(message, pipeline, currentIndex, async: false).EnsureCompleted(); + + /// + public sealed override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) => + await ProcessSyncOrAsync(message, pipeline, currentIndex, async: true).ConfigureAwait(false); + + private async ValueTask ProcessSyncOrAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex, bool async) + { + _messageLogger ??= new PipelineMessageLogger(_loggingOptions.GetPipelineMessageSanitizer(), _loggingOptions.LoggerFactory); + + // EventLevel.Warning / LogLevel.Warning is the highest level logged by this policy. + // PipelineMessageLogger.IsEnabled checks to see: if an ILogger was provided - ensure it is enabled for at least LogLevel.Warning OR if + // an ILogger was not provided, see if an EventListener exists that is actively listening for at least EventLevel.Warning. + // If nothing is listening for logs then this policy immediately passes control to the next policy in the pipeline and returns when control + // is passed back. This avoids any performance hit from logging when no one is listening/collecting logs. + if (!_messageLogger.IsEnabled(LogLevel.Warning, EventLevel.Warning)) + { + if (async) + { + await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false); + } + else + { + ProcessNext(message, pipeline, currentIndex); + } + return; + } + + PipelineRequest request = message.Request; + + string requestId = message.Request.ClientRequestId ?? string.Empty; + + _messageLogger.LogRequest(requestId, request, _clientAssembly); + + if (_enableMessageContentLogging && request.Content != null && _messageLogger.IsEnabled(LogLevel.Debug, EventLevel.Verbose)) + { + // Convert binary content to bytes + using var memoryStream = new MaxLengthStream(_maxLength); + if (async) + { + await request.Content.WriteToAsync(memoryStream, message.CancellationToken).ConfigureAwait(false); + } + else + { + request.Content.WriteTo(memoryStream, message.CancellationToken); + } + byte[] bytes = memoryStream.ToArray(); + + Encoding? requestTextEncoding = null; + // Try to extract a text encoding from the headers + if (request.Headers.TryGetValue("Content-Type", out var contentType) && contentType != null) + { + ContentTypeUtilities.TryGetTextEncoding(contentType, out requestTextEncoding); + } + + _messageLogger.LogRequestContent(requestId, bytes, requestTextEncoding); + } + + var before = Stopwatch.GetTimestamp(); + + // Any exceptions thrown are logged in the transport + if (async) + { + await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false); + } + else + { + ProcessNext(message, pipeline, currentIndex); + } + + var after = Stopwatch.GetTimestamp(); + + PipelineResponse response = message.Response!; + + double elapsed = (after - before) / (double)Stopwatch.Frequency; + + if (response.IsError) + { + _messageLogger.LogErrorResponse(requestId, response, elapsed); + } + else + { + _messageLogger.LogResponse(requestId, response, elapsed); + } + + if (_enableMessageContentLogging && response.ContentStream != null && _messageLogger.IsEnabled(LogLevel.Information, EventLevel.Informational)) + { + Encoding? responseTextEncoding = null; + + if (response.Headers.TryGetValue("Content-Type", out var contentType) && contentType != null) + { + ContentTypeUtilities.TryGetTextEncoding(contentType, out responseTextEncoding); + } + + if (message.BufferResponse || response.ContentStream.CanSeek) + { + byte[]? responseBytes; + if (message.BufferResponse) + { + // Content is buffered, so log the first _maxLength bytes + ReadOnlyMemory contentAsMemory = response.Content.ToMemory(); + var length = Math.Min(contentAsMemory.Length, _maxLength); + responseBytes = contentAsMemory.Span.Slice(0, length).ToArray(); + } + else + { + responseBytes = new byte[_maxLength]; + response.ContentStream.Read(responseBytes, 0, _maxLength); + response.ContentStream.Seek(0, SeekOrigin.Begin); + } + + if (response.IsError) + { + _messageLogger.LogErrorResponseContent(requestId, responseBytes, responseTextEncoding); + } + else + { + _messageLogger.LogResponseContent(requestId, responseBytes, responseTextEncoding); + } + } + else + { + response.ContentStream = new LoggingStream(_messageLogger, requestId, _maxLength, response.ContentStream, response.IsError, responseTextEncoding); + } + } + } + + #region MaxLengthStream + private class MaxLengthStream : MemoryStream + { + private int _bytesLeft; + + public MaxLengthStream(int maxLength) : base() + { + _bytesLeft = maxLength; + } + + public override void Write(byte[] buffer, int offset, int count) + { + DecrementLength(ref count); + if (count > 0) + { + base.Write(buffer, offset, count); + } + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return count > 0 ? base.WriteAsync(buffer, offset, count, cancellationToken) : Task.CompletedTask; + } + + private void DecrementLength(ref int count) + { + var left = Math.Min(count, _bytesLeft); + count = left; + + _bytesLeft -= count; + } + } + #endregion + + #region LoggingStream + private class LoggingStream : Stream + { + private readonly string _requestId; + private int _remainingBytesToLog; + private readonly Stream _originalStream; + private readonly bool _error; + private readonly Encoding? _textEncoding; + private int _blockNumber; + private readonly PipelineMessageLogger _messageLogger; + + public LoggingStream(PipelineMessageLogger messageLogger, string requestId, int maxLoggedBytes, Stream originalStream, bool error, Encoding? textEncoding) + { + // Should only wrap non-seekable streams + Debug.Assert(!originalStream.CanSeek); + _requestId = requestId; + _remainingBytesToLog = maxLoggedBytes; + _originalStream = originalStream; + _error = error; + _textEncoding = textEncoding; + _messageLogger = messageLogger; + } + + public override int Read(byte[] buffer, int offset, int count) + { + var numBytesRead = _originalStream.Read(buffer, offset, count); + + LogBuffer(buffer, offset, numBytesRead); + + return numBytesRead; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var numBytesRead = await _originalStream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + + LogBuffer(buffer, offset, numBytesRead); + + return numBytesRead; + } + +#if !NETSTANDARD2_0 + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var numBytesRead = await _originalStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + + LogMemory(buffer, numBytesRead); + + return numBytesRead; + } +#endif + + public override bool CanRead => _originalStream.CanRead; + public override bool CanSeek => _originalStream.CanSeek; + public override long Length => _originalStream.Length; + public override long Position + { + get => _originalStream.Position; + set => _originalStream.Position = value; + } + + // Make this stream readonly + public override bool CanWrite => false; + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException("This stream does not support seek operations."); + } + + // Make this stream readonly + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException("This stream is read-only."); + } + + // Make this stream readonly + public override void SetLength(long value) + { + throw new NotSupportedException("This stream is read-only."); + } + + public override void Flush() + { + _originalStream.Flush(); + } + + public override void Close() + { + _originalStream.Close(); + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + _originalStream.Dispose(); + } + + #region Helpers + + private void LogMemory(Memory memory, int numBytesReadIntoMemory) + { + // This is intentionally not thread-safe because synchronizing reads + // should be done by the caller. + var bytesToLog = Math.Min(numBytesReadIntoMemory, _remainingBytesToLog); + _remainingBytesToLog -= bytesToLog; + + if (bytesToLog == 0) + { + return; + } + + byte[] bytes = new byte[bytesToLog]; + memory.Slice(0, bytesToLog).Span.CopyTo(bytes); + + if (_error) + { + _messageLogger.LogErrorResponseContentBlock(_requestId, _blockNumber, bytes, _textEncoding); + } + else + { + _messageLogger.LogResponseContentBlock(_requestId, _blockNumber, bytes, _textEncoding); + } + + _blockNumber++; + } + + private void LogBuffer(byte[] buffer, int offset, int numBytesReadIntoBuffer) + { + // This is intentionally not thread-safe because synchronizing reads + // should be done by the caller. + var bytesToLog = Math.Min(numBytesReadIntoBuffer, _remainingBytesToLog); + _remainingBytesToLog -= bytesToLog; + + if (bytesToLog == 0 || buffer == null) + { + return; + } + + byte[] bytes; + if (bytesToLog == numBytesReadIntoBuffer && offset == 0) + { + bytes = buffer; + } + else + { + bytes = new byte[bytesToLog]; + Buffer.BlockCopy(buffer, offset, bytes, 0, bytesToLog); + } + + if (_error) + { + _messageLogger.LogErrorResponseContentBlock(_requestId, _blockNumber, bytes, _textEncoding); + } + else + { + _messageLogger.LogResponseContentBlock(_requestId, _blockNumber, bytes, _textEncoding); + } + + _blockNumber++; + } + #endregion + } + #endregion +} diff --git a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs index 6465ea6c1245..c72a774a1445 100644 --- a/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs +++ b/sdk/core/System.ClientModel/src/Pipeline/PipelineTransport.cs @@ -7,6 +7,7 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; namespace System.ClientModel.Primitives; @@ -16,6 +17,29 @@ namespace System.ClientModel.Primitives; /// public abstract class PipelineTransport : PipelinePolicy { + private readonly PipelineTransportLogger? _pipelineTransportLogger; + + /// + /// Creates a new instance of the class. + /// + protected PipelineTransport() + { + } + + /// + /// Creates a new instance of the class. + /// + /// If client-wide logging is enabled for this pipeline. + /// The to use to create an instance for logging. + /// If one is not provided, logs are written to Event Source by default. + protected PipelineTransport(bool enableLogging, ILoggerFactory? loggerFactory) + { + if (enableLogging) + { + _pipelineTransportLogger = new(loggerFactory); + } + } + #region CreateMessage /// @@ -63,7 +87,17 @@ public PipelineMessage CreateMessage() /// The containing the /// request that was sent and response that was received by the transport. public void Process(PipelineMessage message) - => ProcessSyncOrAsync(message, async: false).EnsureCompleted(); + { + try + { + ProcessSyncOrAsync(message, async: false).EnsureCompleted(); + } + catch (Exception ex) + { + _pipelineTransportLogger?.LogExceptionResponse(message.Request.ClientRequestId ?? string.Empty, ex); + throw; + } + } /// /// Sends the HTTP request contained by @@ -72,7 +106,17 @@ public void Process(PipelineMessage message) /// The containing the /// request that was sent and response that was received by the transport. public async ValueTask ProcessAsync(PipelineMessage message) - => await ProcessSyncOrAsync(message, async: true).ConfigureAwait(false); + { + try + { + await ProcessSyncOrAsync(message, async: true).ConfigureAwait(false); + } + catch (Exception ex) + { + _pipelineTransportLogger?.LogExceptionResponse(message.Request.ClientRequestId ?? string.Empty, ex); + throw; + } + } private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) { @@ -84,6 +128,8 @@ private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) using CancellationTokenSource timeoutTokenSource = CancellationTokenSource.CreateLinkedTokenSource(messageToken); timeoutTokenSource.CancelAfter(networkTimeout); + var before = Stopwatch.GetTimestamp(); + try { message.CancellationToken = timeoutTokenSource.Token; @@ -108,9 +154,17 @@ private async ValueTask ProcessSyncOrAsync(PipelineMessage message, bool async) timeoutTokenSource.CancelAfter(Timeout.Infinite); } + var after = Stopwatch.GetTimestamp(); + double elapsed = (after - before) / (double)Stopwatch.Frequency; + message.AssertResponse(); message.Response!.IsErrorCore = ClassifyResponse(message); + if (elapsed > ClientLoggingOptions.RequestTooLongSeconds) + { + _pipelineTransportLogger?.LogResponseDelay(message.Request.ClientRequestId ?? string.Empty, elapsed); + } + // The remainder of the method handles response content according to // buffering logic specified by value of message.BufferResponse. diff --git a/sdk/core/System.ClientModel/src/System.ClientModel.csproj b/sdk/core/System.ClientModel/src/System.ClientModel.csproj index 291c0ef86906..c802ef3e48b5 100644 --- a/sdk/core/System.ClientModel/src/System.ClientModel.csproj +++ b/sdk/core/System.ClientModel/src/System.ClientModel.csproj @@ -12,8 +12,10 @@ + + \ No newline at end of file diff --git a/sdk/core/System.ClientModel/tests/Options/ClientLoggingOptionsTests.cs b/sdk/core/System.ClientModel/tests/Options/ClientLoggingOptionsTests.cs new file mode 100644 index 000000000000..11fa8e5aa11f --- /dev/null +++ b/sdk/core/System.ClientModel/tests/Options/ClientLoggingOptionsTests.cs @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Options +{ + internal class ClientLoggingOptionsTests + { + [Test] + public void NonCollectionPropertiesDefaultToNull() + { + ClientLoggingOptions options = new(); + Assert.AreEqual(null, options.EnableLogging); + Assert.AreEqual(null, options.EnableMessageLogging); + Assert.AreEqual(null, options.EnableMessageContentLogging); + Assert.AreEqual(null, options.MessageContentSizeLimit); + Assert.AreEqual(null, options.LoggerFactory); + } + + [Test] + public void CollectionPropertiesDefaultsAreSet() + { + string[] expectedDefaultAllowedHeaderNames = [ + "traceparent", + "Accept", + "Cache-Control", + "Connection", + "Content-Length", + "Content-Type", + "Date", + "ETag", + "Expires", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Unmodified-Since", + "Last-Modified", + "Pragma", + "Retry-After", + "Server", + "Transfer-Encoding", + "User-Agent", + "WWW-Authenticate" ]; + string[] expectedDefaultAllowedQueryParameters = ["api-version"]; + + ClientLoggingOptions options = new(); + CollectionAssert.AreEquivalent(expectedDefaultAllowedHeaderNames, options.AllowedHeaderNames); + CollectionAssert.AreEquivalent(expectedDefaultAllowedQueryParameters, options.AllowedQueryParameters); + } + + [Test] + public void CanModifyOptions() + { + ClientLoggingOptions options = new(); + + options.MessageContentSizeLimit = 5; + options.LoggerFactory = NullLoggerFactory.Instance; + options.EnableLogging = false; + options.EnableMessageLogging = false; + options.EnableMessageContentLogging = false; + options.AllowedHeaderNames.Add("Hello"); + options.AllowedQueryParameters.Add("Hello"); + + Assert.AreEqual(5, options.MessageContentSizeLimit); + Assert.AreEqual(NullLoggerFactory.Instance, options.LoggerFactory); + Assert.AreEqual(false, options.EnableLogging); + Assert.AreEqual(false, options.EnableMessageLogging); + Assert.AreEqual(false, options.EnableMessageContentLogging); + CollectionAssert.Contains(options.AllowedHeaderNames, "Hello"); + CollectionAssert.Contains(options.AllowedQueryParameters, "Hello"); + } + + [Test] + public void DefaultOptionsFreeze() + { + ClientLoggingOptions options = new(); + + options.Freeze(); + + Assert.Throws(() => options.AllowedHeaderNames.Add("ShouldNotAdd")); + Assert.Throws(() => options.AllowedQueryParameters.Add("ShouldNotAdd")); + Assert.Throws(() => options.EnableLogging = true); + Assert.Throws(() => options.EnableMessageLogging = true); + Assert.Throws(() => options.EnableMessageContentLogging = true); + Assert.Throws(() => options.MessageContentSizeLimit = 5); + Assert.Throws(() => options.LoggerFactory = new NullLoggerFactory()); + } + + [Test] + public void CustomizedOptionsFreeze() + { + ClientLoggingOptions options = new(); + options.MessageContentSizeLimit = 5; + options.LoggerFactory = NullLoggerFactory.Instance; + options.EnableLogging = false; + options.EnableMessageLogging = false; + options.EnableMessageContentLogging = false; + options.AllowedHeaderNames.Add("Hello"); + options.AllowedQueryParameters.Add("Hello"); + + options.Freeze(); + + Assert.Throws(() => options.AllowedHeaderNames.Add("ShouldNotAdd")); + Assert.Throws(() => options.AllowedQueryParameters.Add("ShouldNotAdd")); + Assert.Throws(() => options.EnableLogging = true); + Assert.Throws(() => options.EnableMessageLogging = true); + Assert.Throws(() => options.EnableMessageContentLogging = true); + Assert.Throws(() => options.MessageContentSizeLimit = 10); + Assert.Throws(() => options.LoggerFactory = new NullLoggerFactory()); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/Options/ClientPipelineOptionsTests.cs b/sdk/core/System.ClientModel/tests/Options/ClientPipelineOptionsTests.cs index 8b52090ff5f0..3248f29ca1e2 100644 --- a/sdk/core/System.ClientModel/tests/Options/ClientPipelineOptionsTests.cs +++ b/sdk/core/System.ClientModel/tests/Options/ClientPipelineOptionsTests.cs @@ -149,32 +149,46 @@ public async Task CanAddPoliciesAtAllPositions() public void CannotModifyOptionsAfterFrozen() { ClientPipelineOptions options = new(); + options.ClientLoggingOptions = new(); ClientPipeline pipeline = ClientPipeline.Create(options); Assert.Throws(() => options.RetryPolicy = new MockRetryPolicy()); + Assert.Throws(() + => options.MessageLoggingPolicy = new MockPipelinePolicy()); Assert.Throws(() => options.Transport = new MockPipelineTransport("Transport")); Assert.Throws(() => options.NetworkTimeout = TimeSpan.MinValue); Assert.Throws(() => options.AddPolicy(new ObservablePolicy("A"), PipelinePosition.PerCall)); + Assert.Throws(() + => options.ClientLoggingOptions = new()); + Assert.Throws(() + => options.ClientLoggingOptions.EnableLogging = true); } [Test] public void CannotModifyOptionsAfterExplicitlyFrozen() { ClientPipelineOptions options = new(); + options.ClientLoggingOptions = new(); options.Freeze(); Assert.Throws(() => options.RetryPolicy = new MockRetryPolicy()); + Assert.Throws(() + => options.MessageLoggingPolicy = new MockPipelinePolicy()); Assert.Throws(() => options.Transport = new MockPipelineTransport("Transport")); Assert.Throws(() => options.NetworkTimeout = TimeSpan.MinValue); Assert.Throws(() => options.AddPolicy(new ObservablePolicy("A"), PipelinePosition.PerCall)); + Assert.Throws(() + => options.ClientLoggingOptions = new()); + Assert.Throws(() + => options.ClientLoggingOptions.EnableLogging = true); } #region Helpers diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs index 86e3bc1d831f..c226073da57c 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineFunctionalTests.cs @@ -1,22 +1,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using Azure.Core.TestFramework; -using ClientModel.Tests.Mocks; -using Microsoft.AspNetCore.Http; -using Moq; -using NUnit.Framework; using System.ClientModel.Primitives; +using System.ClientModel.Tests.TestFramework; using System.Collections.Generic; +using System.Diagnostics.Tracing; using System.IO; +using System.Linq; using System.Threading; using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests; +using ClientModel.Tests.Mocks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Moq; +using NUnit.Framework; using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; namespace System.ClientModel.Tests.Pipeline; public class ClientPipelineFunctionalTests : SyncAsyncTestBase { + private const string LoggingPolicyCategoryName = "System.ClientModel.Primitives.MessageLoggingPolicy"; + private const string PipelineTransportCategoryName = "System.ClientModel.Primitives.PipelineTransport"; + private const string RetryPolicyCategoryName = "System.ClientModel.Primitives.ClientRetryPolicy"; + public ClientPipelineFunctionalTests(bool isAsync) : base(isAsync) { } @@ -348,7 +357,7 @@ public async Task RetriesTimeoutsServerTimeouts() } [Test] - public async Task DoesntRetryClientCancellation() + public async Task DoesNotRetryClientCancellation() { var testDoneTcs = new CancellationTokenSource(); int i = 0; @@ -431,6 +440,274 @@ public async Task RetriesBufferedBodyTimeout() #endregion + #region Test default logging policy behavior + + [Test] + public async Task LogsRequestAndResponseToEventSource() + { + using TestClientEventListener eventListener = new(); + + ClientPipeline pipeline = ClientPipeline.Create(); + + using TestServer testServer = new( + async context => + { + context.Response.StatusCode = 201; + await context.Response.WriteAsync("Hello World!"); + }); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + message.BufferResponse = true; + + await pipeline.SendSyncOrAsync(message, IsAsync); + + // Request + EventWrittenEventArgs args = eventListener.SingleEventById(1, e => e.EventSource.Name == "System.ClientModel"); + Assert.AreEqual(EventLevel.Informational, args.Level); + Assert.AreEqual("Request", args.EventName); + + // Response + args = eventListener.SingleEventById(5, e => e.EventSource.Name == "System.ClientModel"); + Assert.AreEqual(EventLevel.Informational, args.Level); + Assert.AreEqual("Response", args.EventName); + Assert.AreEqual(201, args.GetProperty("status")); + + // No other events should have been logged + Assert.AreEqual(2, eventListener.EventData.Count()); + } + + [Test] + public void LogsRequestAndExceptionResponseToEventSource() + { + using TestClientEventListener eventListener = new(); + + ClientPipeline pipeline = ClientPipeline.Create(); + + using TestServer testServer = new( + async context => + { + await context.Response.WriteAsync("Hello World!"); + throw new Exception("Error"); + }); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + message.BufferResponse = true; + + Assert.ThrowsAsync(async () => await pipeline.SendSyncOrAsync(message, IsAsync)); + + // Request Events + IEnumerable args = eventListener.EventsById(1); + Assert.AreEqual(4, args.Count()); + + // Exception Response Events + args = eventListener.EventsById(18); + Assert.AreEqual(4, args.Count()); + foreach (EventWrittenEventArgs responseEventArgs in args) + { + Assert.AreEqual(EventLevel.Informational, responseEventArgs.Level); + Assert.AreEqual("ExceptionResponse", responseEventArgs.EventName); + Assert.True((responseEventArgs.GetProperty("exception")).Contains("Exception")); + } + + // 4 request events, 3 request retrying, 4 exception response + Assert.AreEqual(11, eventListener.EventData.Count()); + } + + [Test] + public async Task LogsRequestAndRetryToEventSource() + { + using TestClientEventListener eventListener = new(); + + ClientPipeline pipeline = ClientPipeline.Create(); + + int responseNum = 0; + using TestServer testServer = new( + async context => + { + switch (responseNum) + { + case 0: + context.Response.StatusCode = 429; + await context.Response.WriteAsync("Try again"); + break; + default: + context.Response.StatusCode = 201; + await context.Response.WriteAsync("Success"); + break; + } + responseNum++; + }); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + message.BufferResponse = true; + + await pipeline.SendSyncOrAsync(message, IsAsync); + + // Request Events + IEnumerable args = eventListener.EventsById(1); + Assert.AreEqual(2, args.Count()); + + // Retry event + EventWrittenEventArgs arg = eventListener.SingleEventById(10); + Assert.AreEqual("RequestRetrying", arg.EventName); + + // Error response event + arg = eventListener.SingleEventById(8); + Assert.AreEqual("ErrorResponse", arg.EventName); + + // Response event + arg = eventListener.SingleEventById(5); + Assert.AreEqual("Response", arg.EventName); + + // No other events should have been logged + Assert.AreEqual(5, eventListener.EventData.Count()); + } + + [Test] + public async Task LogsRequestAndResponseToILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientPipelineOptions options = new() { ClientLoggingOptions = new() { LoggerFactory = factory } }; + ClientPipeline pipeline = ClientPipeline.Create(options); + + using TestServer testServer = new( + async context => + { + context.Response.StatusCode = 201; + await context.Response.WriteAsync("Hello World!"); + }); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + message.BufferResponse = true; + + await pipeline.SendSyncOrAsync(message, IsAsync); + + // Message logger + TestLogger messageLogger = factory.GetLogger(LoggingPolicyCategoryName); + + // Request + LoggerEvent log = messageLogger.SingleEventById(1); + Assert.AreEqual(LogLevel.Information, log.LogLevel); + Assert.AreEqual("Request", log.EventId.Name); + + // Response + log = messageLogger.SingleEventById(5); + Assert.AreEqual(LogLevel.Information, log.LogLevel); + Assert.AreEqual("Response", log.EventId.Name); + Assert.AreEqual(201, log.GetValueFromArguments("status")); + + // No other events should have been logged + Assert.AreEqual(2, messageLogger.Logs.Count()); + } + + [Test] + public void LogsRequestAndExceptionResponseToILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientPipelineOptions options = new() { ClientLoggingOptions = new() { LoggerFactory = factory } }; + ClientPipeline pipeline = ClientPipeline.Create(options); + + using TestServer testServer = new( + async context => + { + await context.Response.WriteAsync("Hello World!"); + context.Abort(); + }); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + message.BufferResponse = true; + + Assert.ThrowsAsync(async () => await pipeline.SendSyncOrAsync(message, IsAsync)); + + // Message logger + TestLogger messageLogger = factory.GetLogger(LoggingPolicyCategoryName); + + // Transport Logger + TestLogger transportLogger = factory.GetLogger(PipelineTransportCategoryName); + + // Request Events + IEnumerable logs = messageLogger.EventsById(1); + Assert.AreEqual(4, logs.Count()); + + // Exception Response Events + logs = transportLogger.EventsById(18); + Assert.AreEqual(4, logs.Count()); + foreach (LoggerEvent responseEventLog in logs) + { + Assert.AreEqual(LogLevel.Information, responseEventLog.LogLevel); + Assert.AreEqual("ExceptionResponse", responseEventLog.EventId.Name); + } + + // No other events should have been logged + Assert.AreEqual(4, messageLogger.Logs.Count()); + Assert.AreEqual(4, transportLogger.Logs.Count()); + } + + [Test] + public async Task LogsRequestAndRetryToILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientPipelineOptions options = new() { ClientLoggingOptions = new() { LoggerFactory = factory } }; + ClientPipeline pipeline = ClientPipeline.Create(options); + + int responseNum = 0; + using TestServer testServer = new( + async context => + { + switch (responseNum) + { + case 0: + context.Response.StatusCode = 429; + await context.Response.WriteAsync("Try again"); + break; + default: + context.Response.StatusCode = 201; + await context.Response.WriteAsync("Success"); + break; + } + responseNum++; + }); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Uri = testServer.Address; + message.BufferResponse = true; + + await pipeline.SendSyncOrAsync(message, IsAsync); + + // Message logger + TestLogger messageLogger = factory.GetLogger(LoggingPolicyCategoryName); + + // Retry Logger + TestLogger retryLogger = factory.GetLogger(RetryPolicyCategoryName); + + // Request Events + IEnumerable args = messageLogger.EventsById(1); + Assert.AreEqual(2, args.Count()); + + // Retry event + LoggerEvent arg = retryLogger.SingleEventById(10); + Assert.AreEqual("RequestRetrying", arg.EventId.Name); + + // Error response event + arg = messageLogger.SingleEventById(8); + Assert.AreEqual("ErrorResponse", arg.EventId.Name); + + // Response event + arg = messageLogger.SingleEventById(5); + Assert.AreEqual("Response", arg.EventId.Name); + + // No other events should have been logged + Assert.AreEqual(4, messageLogger.Logs.Count()); + Assert.AreEqual(1, retryLogger.Logs.Count()); + } + + #endregion + #region Test parallel connections [Test] diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs index a615fca52baa..a6fade00ebb3 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientPipelineTests.cs @@ -205,6 +205,7 @@ public async Task CanCreateWithClientAuthorAndClientUserPolicies() ClientPipelineOptions options = new() { RetryPolicy = new ObservablePolicy("RetryPolicy"), + MessageLoggingPolicy = new ObservablePolicy("LoggingPolicy"), Transport = new ObservableTransport("Transport") }; @@ -246,7 +247,7 @@ public async Task CanCreateWithClientAuthorAndClientUserPolicies() List observations = ObservablePolicy.GetData(message); int index = 0; - Assert.AreEqual(27, observations.Count); + Assert.AreEqual(29, observations.Count); Assert.AreEqual("Request:ClientPerCallPolicyA", observations[index++]); Assert.AreEqual("Request:ClientPerCallPolicyB", observations[index++]); @@ -262,6 +263,8 @@ public async Task CanCreateWithClientAuthorAndClientUserPolicies() Assert.AreEqual("Request:UserPerTryPolicyA", observations[index++]); Assert.AreEqual("Request:UserPerTryPolicyB", observations[index++]); + Assert.AreEqual("Request:LoggingPolicy", observations[index++]); + Assert.AreEqual("Request:ClientBeforeTransportPolicyA", observations[index++]); Assert.AreEqual("Request:ClientBeforeTransportPolicyB", observations[index++]); @@ -276,6 +279,8 @@ public async Task CanCreateWithClientAuthorAndClientUserPolicies() Assert.AreEqual("Response:ClientBeforeTransportPolicyB", observations[index++]); Assert.AreEqual("Response:ClientBeforeTransportPolicyA", observations[index++]); + Assert.AreEqual("Response:LoggingPolicy", observations[index++]); + Assert.AreEqual("Response:UserPerTryPolicyB", observations[index++]); Assert.AreEqual("Response:UserPerTryPolicyA", observations[index++]); @@ -296,6 +301,7 @@ public async Task RequestOptionsCanCustomizePipeline() { ClientPipelineOptions pipelineOptions = new ClientPipelineOptions(); pipelineOptions.RetryPolicy = new ObservablePolicy("RetryPolicy"); + pipelineOptions.MessageLoggingPolicy = new ObservablePolicy("LoggingPolicy"); pipelineOptions.Transport = new ObservableTransport("Transport"); ClientPipeline pipeline = ClientPipeline.Create(pipelineOptions); @@ -312,13 +318,15 @@ public async Task RequestOptionsCanCustomizePipeline() List observations = ObservablePolicy.GetData(message); int index = 0; - Assert.AreEqual(9, observations.Count); + Assert.AreEqual(11, observations.Count); Assert.AreEqual("Request:A", observations[index++]); Assert.AreEqual("Request:RetryPolicy", observations[index++]); Assert.AreEqual("Request:B", observations[index++]); + Assert.AreEqual("Request:LoggingPolicy", observations[index++]); Assert.AreEqual("Request:C", observations[index++]); Assert.AreEqual("Transport:Transport", observations[index++]); Assert.AreEqual("Response:C", observations[index++]); + Assert.AreEqual("Response:LoggingPolicy", observations[index++]); Assert.AreEqual("Response:B", observations[index++]); Assert.AreEqual("Response:RetryPolicy", observations[index++]); Assert.AreEqual("Response:A", observations[index++]); diff --git a/sdk/core/System.ClientModel/tests/Pipeline/ClientRetryPolicyTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/ClientRetryPolicyTests.cs index a6c893724a01..16db2e408104 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/ClientRetryPolicyTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/ClientRetryPolicyTests.cs @@ -47,7 +47,7 @@ public async Task DoesNotExceedRetryCount() { ClientPipelineOptions options = new() { - Transport = new MockPipelineTransport("Transport", i => 500) + Transport = new MockPipelineTransport("Transport", _ => new MockPipelineResponse(500)) }; ClientPipeline pipeline = ClientPipeline.Create(options); @@ -76,7 +76,7 @@ public async Task CanConfigureMaxRetryCount() ClientPipelineOptions options = new() { RetryPolicy = new MockRetryPolicy(maxRetryCount, i => TimeSpan.FromMilliseconds(10)), - Transport = new MockPipelineTransport("Transport", i => 500) + Transport = new MockPipelineTransport("Transport", _ => new MockPipelineResponse(500)) }; ClientPipeline pipeline = ClientPipeline.Create(options); @@ -169,16 +169,17 @@ public void RespectsRetryAfterDateHeader() public async Task ShouldRetryIsCalledOnlyForErrors() { Exception retriableException = new IOException(); + int retryCount = 0; - MockRetryPolicy retryPolicy = new MockRetryPolicy(); - MockPipelineTransport transport = new MockPipelineTransport("Transport", responseFactory); + MockRetryPolicy retryPolicy = new(); + MockPipelineTransport transport = new("Transport", responseFactory); - int responseFactory(int i) - => i switch + MockPipelineResponse responseFactory(PipelineMessage m) + => retryCount++ switch { - 0 => 500, + 0 => new MockPipelineResponse(500), 1 => throw retriableException, - 2 => 200, + 2 => new MockPipelineResponse(200), _ => throw new InvalidOperationException(), }; @@ -190,9 +191,9 @@ int responseFactory(int i) ClientPipeline pipeline = ClientPipeline.Create(options); // Validate the state of the retry policy at the transport. - transport.OnSendingRequest = i => + transport.OnSendingRequest = _ => { - switch (i) + switch (retryCount) { case 0: Assert.IsFalse(retryPolicy.ShouldRetryCalled); @@ -237,16 +238,17 @@ int responseFactory(int i) public async Task CallbacksAreCalledForErrorResponseAndException() { Exception retriableException = new IOException(); + int retryCount = 0; - MockRetryPolicy retryPolicy = new MockRetryPolicy(); - MockPipelineTransport transport = new MockPipelineTransport("Transport", responseFactory); + MockRetryPolicy retryPolicy = new(); + MockPipelineTransport transport = new("Transport", responseFactory); - int responseFactory(int i) - => i switch + MockPipelineResponse responseFactory(PipelineMessage m) + => retryCount++ switch { - 0 => 500, + 0 => new MockPipelineResponse(500), 1 => throw retriableException, - 2 => 200, + 2 => new MockPipelineResponse(200), _ => throw new InvalidOperationException(), }; @@ -260,7 +262,7 @@ int responseFactory(int i) // Validate the state of the retry policy at the transport. transport.OnSendingRequest = i => { - switch (i) + switch (retryCount) { case 0: Assert.IsTrue(retryPolicy.OnSendingRequestCalled); @@ -338,17 +340,18 @@ public void RethrowsAggregateExceptionAfterMaxRetryCount() new IOException(), new IOException(), new IOException() }; + int retryCount = 0; MockRetryPolicy retryPolicy = new MockRetryPolicy(); MockPipelineTransport transport = new MockPipelineTransport("Transport", responseFactory); - int responseFactory(int i) - => i switch + MockPipelineResponse responseFactory(PipelineMessage i) + => retryCount++ switch { - 0 => throw exceptions[i], - 1 => throw exceptions[i], - 2 => throw exceptions[i], - 3 => throw exceptions[i], + 0 => throw exceptions[0], + 1 => throw exceptions[1], + 2 => throw exceptions[2], + 3 => throw exceptions[3], _ => throw new InvalidOperationException(), }; diff --git a/sdk/core/System.ClientModel/tests/Pipeline/MessageLoggingPolicyTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/MessageLoggingPolicyTests.cs new file mode 100644 index 000000000000..d9453b7ff401 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/Pipeline/MessageLoggingPolicyTests.cs @@ -0,0 +1,644 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.ClientModel.Tests.TestFramework; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests; +using ClientModel.Tests.Mocks; +using Microsoft.Extensions.Logging; +using NUnit.Framework; +using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; + +namespace System.ClientModel.Tests.Pipeline; + +// Avoid running these tests in parallel with anything else that's sharing the event source +[NonParallelizable] +public class MessageLoggingPolicyTests(bool isAsync) : SyncAsyncTestBase(isAsync) +{ + private const int RequestEvent = 1; + private const int RequestContentEvent = 2; + private const int ResponseEvent = 5; + private const int ResponseContentEvent = 6; + private const int ErrorResponseEvent = 8; + private const int ErrorResponseContentEvent = 9; + private const int ResponseContentBlockEvent = 11; + private const int ErrorResponseContentBlockEvent = 12; + private const int ResponseContentTextEvent = 13; + private const int ErrorResponseContentTextEvent = 14; + private const int ResponseContentTextBlockEvent = 15; + private const int ErrorResponseContentTextBlockEvent = 16; + private const int RequestContentTextEvent = 17; + private const string LoggingPolicyCategoryName = "System.ClientModel.Primitives.MessageLoggingPolicy"; + private const string PipelineTransportCategoryName = "System.ClientModel.Primitives.PipelineTransport"; + private const string RetryPolicyCategoryName = "System.ClientModel.Primitives.ClientRetryPolicy"; + private const string SystemClientModelEventSourceName = "System.ClientModel"; + private readonly MockResponseHeaders _defaultHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + private readonly MockResponseHeaders _defaultTextHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Content-Type", "text/plain" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + + [Test] + public async Task OptionsCanBeUpdatedUntilFrozenByPipeline() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory + }; + + MessageLoggingPolicy loggingPolicy = new(loggingOptions); + + ClientPipelineOptions options = new() + { + MessageLoggingPolicy = loggingPolicy, + Transport = new MockPipelineTransport("Transport", [200]) + }; + + loggingOptions.EnableMessageContentLogging = true; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData([1,2,3])); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + logger.GetAndValidateSingleEvent(RequestContentEvent, "RequestContent", LogLevel.Debug); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedByDefaultToEventSource(bool isError, bool asText) + { + using TestClientEventListener listener = new(); + ClientLoggingOptions loggingOptions = new(); + + await SendSimpleRequestResponseSyncOrAsync(isError, loggingOptions, asText, IsAsync); + + listener.AssertNoContentLogged(); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedByDefaultToILogger(bool isError, bool asText) + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory + }; + + await SendSimpleRequestResponseSyncOrAsync(isError, loggingOptions, asText, IsAsync); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + logger.AssertNoContentLogged(); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedWhenDisabledToEventSource(bool isError, bool asText) + { + using TestClientEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = false + }; + + await SendSimpleRequestResponseSyncOrAsync(isError, loggingOptions, asText, IsAsync); + + listener.AssertNoContentLogged(); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedWhenDisabledToILogger(bool isError, bool asText) + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = false, + LoggerFactory = factory + }; + + await SendSimpleRequestResponseSyncOrAsync(isError, loggingOptions, asText, IsAsync); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + logger.AssertNoContentLogged(); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedInBlocksWhenDisabledToEventSource(bool isError, bool asText) + { + using TestClientEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = false + }; + + MockPipelineResponse response = new(isError ? 500 : 200, mockHeaders: asText ? _defaultTextHeaders : _defaultHeaders) + { + ContentStream = new NonSeekableMemoryStream(Encoding.UTF8.GetBytes("Hello world")) + }; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(Encoding.UTF8.GetBytes("Hello world"))); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + listener.AssertNoContentLogged(); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedInBlocksWhenDisabledToILogger(bool isError, bool asText) + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = false, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(isError ? 500 : 200, mockHeaders: asText ? _defaultTextHeaders : _defaultHeaders) + { + ContentStream = new NonSeekableMemoryStream(Encoding.UTF8.GetBytes("Hello world")) + }; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(Encoding.UTF8.GetBytes("Hello world"))); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + logger.AssertNoContentLogged(); + } + + [TestCase(true, true)] + [TestCase(false, true)] + [TestCase(true, false)] + [TestCase(false, false)] + [Test] + public async Task ContentIsNotLoggedWhenEventSourceIsDisabled(bool isError, bool asText) + { + using TestEventListenerWarning listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true + }; + + await SendSimpleRequestResponseSyncOrAsync(isError, loggingOptions, asText, IsAsync); + + listener.AssertNoContentLogged(); + } + + [TestCase(true)] + [TestCase(false)] + [Test] + public async Task ContentEventIsNotWrittenWhenThereIsNoContentToEventSource(bool isError) + { + using TestClientEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true + }; + + MockPipelineResponse response = new(isError ? 500 : 200); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + listener.AssertNoContentLogged(); + } + + [TestCase(true)] + [TestCase(false)] + [Test] + public async Task ContentEventIsNotWrittenWhenThereIsNoContentToILogger(bool isError) + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(isError ? 500 : 200); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + logger.AssertNoContentLogged(); + } + + [Test] + public async Task RequestContentLogsAreLimitedInLengthToEventSource() + { + using TestClientEventListener listener = new(); + + var response = new MockPipelineResponse(500); + byte[] requestContent = [1, 2, 3, 4, 5, 6, 7, 8]; + byte[] requestContentLimited = [1, 2, 3, 4, 5]; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = 5 + }; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContent)); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + EventWrittenEventArgs logEvent = listener.GetAndValidateSingleEvent(RequestContentEvent, "RequestContent", EventLevel.Verbose, SystemClientModelEventSourceName); // RequestContentEvent + Assert.AreEqual(requestContentLimited, logEvent.GetProperty("content")); + CollectionAssert.IsEmpty(listener.EventsById(RequestContentTextEvent)); + } + + [Test] + public async Task RequestContentLogsAreLimitedInLengthToILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + var response = new MockPipelineResponse(500); + byte[] requestContent = [1, 2, 3, 4, 5, 6, 7, 8]; + byte[] requestContentLimited = [1, 2, 3, 4, 5]; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = 5, + LoggerFactory = factory + }; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContent)); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + TestLogger logger = factory!.GetLogger(LoggingPolicyCategoryName); + LoggerEvent logEvent = logger.GetAndValidateSingleEvent(RequestContentEvent, "RequestContent", LogLevel.Debug); + Assert.AreEqual(requestContentLimited, logEvent.GetValueFromArguments("content")); + CollectionAssert.IsEmpty(logger.EventsById(RequestContentTextEvent)); + } + + [Test] + public async Task RequestContentTextLogsAreLimitedInLengthToEventSource() + { + using TestClientEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = 5 + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultTextHeaders); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData("Hello world")); + message.Request.Headers.Add("Content-Type", "text/plain"); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + EventWrittenEventArgs requestEvent = listener!.GetAndValidateSingleEvent(RequestContentTextEvent, "RequestContentText", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual("Hello", requestEvent.GetProperty("content")); + + CollectionAssert.IsEmpty(listener!.EventsById(RequestContentEvent)); + } + + [Test] + public async Task RequestContentTextLogsAreLimitedInLengthToILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = 5, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(500, mockHeaders: _defaultTextHeaders); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData("Hello world")); + message.Request.Headers.Add("Content-Type", "text/plain"); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + LoggerEvent logEvent = logger.GetAndValidateSingleEvent(RequestContentTextEvent, "RequestContentText", LogLevel.Debug); + Assert.AreEqual("Hello", logEvent.GetValueFromArguments("content")); + CollectionAssert.IsEmpty(logger.EventsById(RequestContentEvent)); // RequestContentEvent + } + + [Test] + public async Task SeekableTextResponsesAreLimitedInLengthToEventSource() + { + using TestClientEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + MessageContentSizeLimit = 5, + EnableMessageContentLogging = true + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultTextHeaders); + await SendRequestWithStreamingResponseSyncOrAsync(response, true, loggingOptions); + + EventWrittenEventArgs contentEvent = listener.GetAndValidateSingleEvent(ResponseContentTextEvent, "ResponseContentText", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual("Hello", contentEvent.GetProperty("content")); + } + + [Test] + public async Task SeekableTextResponsesAreLimitedInLengthToILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() + { + MessageContentSizeLimit = 5, + EnableMessageContentLogging = true, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultTextHeaders); + await SendRequestWithStreamingResponseSyncOrAsync(response, true, loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + LoggerEvent contentEvent = logger.GetAndValidateSingleEvent(13, "ResponseContentText", LogLevel.Debug); + Assert.AreEqual("Hello", contentEvent.GetValueFromArguments("content")); + } + + [Test] + public async Task NonSeekableResponsesAreLimitedInLengthEventSource() + { + using TestClientEventListener listener = new(); + ClientLoggingOptions loggingOptions = new() + { + MessageContentSizeLimit = 5, + EnableMessageContentLogging = true + }; + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + + await SendRequestWithStreamingResponseSyncOrAsync(response, false, loggingOptions); + + EventWrittenEventArgs responseEvent = listener.GetAndValidateSingleEvent(11, "ResponseContentBlock", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual(Encoding.UTF8.GetBytes("Hello"), responseEvent.GetProperty("content")); + } + + [Test] + public async Task NonSeekableResponsesAreLimitedInLengthILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() + { + MessageContentSizeLimit = 5, + EnableMessageContentLogging = true, + LoggerFactory = factory + }; + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + + await SendRequestWithStreamingResponseSyncOrAsync(response, false, loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + LoggerEvent responseEvent = logger.GetAndValidateSingleEvent(11, "ResponseContentBlock", LogLevel.Debug); + Assert.AreEqual(Encoding.UTF8.GetBytes("Hello"), responseEvent.GetValueFromArguments("content")); + } + + #region Helpers + + private class TestEventListenerWarning : TestClientEventListener + { + protected override void OnEventSourceCreated(EventSource eventSource) + { + if (eventSource.Name == "System.ClientModel") + { + Console.WriteLine("Warning"); + EnableEvents(eventSource, EventLevel.Warning); + } + } + } + + private async Task SendRequestWithStreamingResponseSyncOrAsync(MockPipelineResponse response, + bool isSeekable, + ClientLoggingOptions loggingOptions) + { + byte[] responseContent = Encoding.UTF8.GetBytes("Hello world"); + if (isSeekable) + { + response.ContentStream = new MemoryStream(responseContent); + } + else + { + response.ContentStream = new NonSeekableMemoryStream(responseContent); + } + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + + // These tests are essentially testing whether the logging policy works + // correctly when responses are buffered (memory stream) and unbuffered + // (non-seekable). In order to validate the intent of the test, we set + // message.BufferResponse accordingly here. + message.BufferResponse = isSeekable; + + await pipeline.SendSyncOrAsync(message, IsAsync); + + var buffer = new byte[11]; + + if (IsAsync) + { +#if NET462 + Assert.AreEqual(6, await response.ContentStream.ReadAsync(buffer, 5, 6)); + Assert.AreEqual(5, await response.ContentStream.ReadAsync(buffer, 6, 5)); + Assert.AreEqual(0, await response.ContentStream.ReadAsync(buffer, 0, 5)); +#else + Assert.AreEqual(6, await response.ContentStream.ReadAsync(buffer.AsMemory(5, 6))); + Assert.AreEqual(5, await response.ContentStream.ReadAsync(buffer.AsMemory(6, 5))); + Assert.AreEqual(0, await response.ContentStream.ReadAsync(buffer.AsMemory(0, 5))); +#endif + } + else + { + Assert.AreEqual(6, response.ContentStream.Read(buffer, 5, 6)); + Assert.AreEqual(5, response.ContentStream.Read(buffer, 6, 5)); + Assert.AreEqual(0, response.ContentStream.Read(buffer, 0, 5)); + } + } + + private async Task SendSimpleRequestResponseSyncOrAsync(bool isError, ClientLoggingOptions loggingOptions, bool contentAsText, bool isAsync) + { + MockPipelineResponse response = new(isError ? 500 : 200); + response.SetContent([1, 2, 3]); + + loggingOptions.AllowedHeaderNames.Add("Custom-Header"); + loggingOptions.AllowedHeaderNames.Add("Custom-Response-Header"); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Headers.Add("Custom-Header", "custom-header-value"); + message.Request.Headers.Add("Date", "08/16/2024"); + + if (contentAsText) + { + response.SetContent("ResponseAsText"); + message.Request.Content = BinaryContent.Create(new BinaryData("RequestAsText")); + message.Request.Headers.Add("Content-Type", "text/plain"); + } + else + { + response.SetContent([1, 2, 3]); + message.Request.Content = BinaryContent.Create(new BinaryData(Encoding.UTF8.GetBytes("Hello world"))); + } + + await pipeline.SendSyncOrAsync(message, IsAsync); + } + + #endregion +} diff --git a/sdk/core/System.ClientModel/tests/Pipeline/PipelineTransportFunctionalTests.cs b/sdk/core/System.ClientModel/tests/Pipeline/PipelineTransportFunctionalTests.cs index d62e71147f8f..dacce0a9e558 100644 --- a/sdk/core/System.ClientModel/tests/Pipeline/PipelineTransportFunctionalTests.cs +++ b/sdk/core/System.ClientModel/tests/Pipeline/PipelineTransportFunctionalTests.cs @@ -1,19 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using Azure.Core.TestFramework; -using ClientModel.Tests; -using ClientModel.Tests.Mocks; -using Microsoft.AspNetCore.Http.Features; -using NUnit.Framework; using System.ClientModel.Primitives; -using System.Collections.Generic; using System.IO; -using System.Net; -using System.Net.Http; using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests.Mocks; +using Microsoft.AspNetCore.Http.Features; +using NUnit.Framework; using SyncAsyncTestBase = ClientModel.Tests.SyncAsyncTestBase; namespace System.ClientModel.Tests.Pipeline; diff --git a/sdk/core/System.ClientModel/tests/Samples/LoggingSamples.cs b/sdk/core/System.ClientModel/tests/Samples/LoggingSamples.cs new file mode 100644 index 000000000000..ae7672b3e0ef --- /dev/null +++ b/sdk/core/System.ClientModel/tests/Samples/LoggingSamples.cs @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Primitives; +using System.Diagnostics.Tracing; +using Maps; +using Microsoft.Extensions.Logging; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Samples; + +public class LoggingSamples +{ + [Test] + public void UseILoggerFactoryToCaptureLogs() + { + #region Snippet:UseILoggerFactoryToCaptureLogs + using ILoggerFactory factory = LoggerFactory.Create(builder => + { + builder.AddConsole().SetMinimumLevel(LogLevel.Information); + }); + + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory + }; + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + + // Create and use client as usual + #endregion + } + + [Test] + public void UseEventSourceToCaptureLogs() + { + #region Snippet:UseEventSourceToCaptureLogs + // In order for an event listener to collect logs, it must be in scope and active + // while the client library is in use. If the listener is disposed or otherwise + // out of scope, logs cannot be collected. + using ConsoleWriterEventListener listener = new(); + + // Create and use client as usual + #endregion + } + + [Test] + public void LoggingRedactedHeaderILogger() + { + #region Snippet:LoggingRedactedHeaderILogger + using ILoggerFactory factory = LoggerFactory.Create(builder => + { + builder.AddConsole(); + }); + + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory + }; + loggingOptions.AllowedHeaderNames.Add("Request-Id"); + loggingOptions.AllowedQueryParameters.Add("api-version"); + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + #endregion + } + + [Test] + public void LoggingRedactedHeaderEventSource() + { + #region Snippet:LoggingRedactedHeaderEventSource + using ConsoleWriterEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new(); + loggingOptions.AllowedHeaderNames.Add("Request-Id"); + loggingOptions.AllowedQueryParameters.Add("api-version"); + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + #endregion + } + + [Test] + public void LoggingAllRedactedHeadersILogger() + { + #region Snippet:LoggingAllRedactedHeadersILogger + using ILoggerFactory factory = LoggerFactory.Create(builder => + { + builder.AddConsole(); + }); + + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory + }; + loggingOptions.AllowedHeaderNames.Add("*"); + loggingOptions.AllowedQueryParameters.Add("*"); + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + #endregion + } + + [Test] + public void LoggingAllRedactedHeadersEventSource() + { + #region Snippet:LoggingAllRedactedHeadersEventSource + using ConsoleWriterEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new(); + loggingOptions.AllowedHeaderNames.Add("*"); + loggingOptions.AllowedQueryParameters.Add("*"); + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + #endregion + } + + [Test] + public void EnableContentLoggingILogger() + { + #region Snippet:EnableContentLoggingILogger + using ILoggerFactory factory = LoggerFactory.Create(builder => + { + builder.AddConsole().SetMinimumLevel(LogLevel.Debug); + }); + + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory, + EnableMessageContentLogging = true + }; + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + #endregion + } + + [Test] + public void EnableContentLoggingEventSource() + { + #region Snippet:EnableContentLoggingEventSource + using ConsoleWriterEventListener listener = new(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true + }; + + MapsClientOptions options = new() + { + ClientLoggingOptions = loggingOptions + }; + #endregion + } + + internal class ConsoleWriterEventListener : EventListener + { + protected override void OnEventSourceCreated(EventSource eventSource) + { + if (eventSource.Name == "System-ClientModel") + { + EnableEvents(eventSource, EventLevel.Informational); + } + } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + Console.WriteLine(eventData.EventId + " " + eventData.EventName + " " + DateTime.Now); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/System.ClientModel.Tests.csproj b/sdk/core/System.ClientModel/tests/System.ClientModel.Tests.csproj index 838309afa64f..799857a08a18 100644 --- a/sdk/core/System.ClientModel/tests/System.ClientModel.Tests.csproj +++ b/sdk/core/System.ClientModel/tests/System.ClientModel.Tests.csproj @@ -13,6 +13,8 @@ + + diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Logging/LoggerEvent.cs b/sdk/core/System.ClientModel/tests/TestFramework/Logging/LoggerEvent.cs new file mode 100644 index 000000000000..5c3758c20b06 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/Logging/LoggerEvent.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.Logging; + +namespace System.ClientModel.Tests.TestFramework; + +public class LoggerEvent +{ + public LoggerEvent(LogLevel logLevel, + string message, + Exception? exception, + EventId eventId, + IReadOnlyList> arguments) + { + LogLevel = logLevel; + Message = message; + Exception = exception; + EventId = eventId; + Arguments = arguments; + } + + public LogLevel LogLevel { get; } + public string Message { get; } + public Exception? Exception { get; } + public EventId EventId { get; } + public IReadOnlyList> Arguments { get; } + + public T GetValueFromArguments(string key) + { + var value = Arguments.Single(kvp => kvp.Key == key).Value; + return (T)value!; + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestClientEventListener.cs b/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestClientEventListener.cs new file mode 100644 index 000000000000..fa2726b7e36d --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestClientEventListener.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.Globalization; +using System.Linq; +using System.Text; +using Azure.Core.TestFramework; +using NUnit.Framework; + +namespace ClientModel.Tests +{ + public class TestClientEventListener : EventListener + { + private volatile bool _disposed; + private readonly ConcurrentQueue _events = new(); + + public IEnumerable EventData => _events; + + /// + /// Creates an instance of . + /// + public TestClientEventListener() + { + } + + protected override void OnEventSourceCreated(EventSource eventSource) + { + // The event source names have to be hardcoded and cannot be configured at runtime by the constructor. This + // is because when an EventListener is instantiated, the OnEventWritten and OnEventSourceCreated callback methods can + // be called before the constructor has completed + // see: https://learn.microsoft.com/dotnet/api/system.diagnostics.tracing.eventlistener#remarks + if (eventSource.Name == "System.ClientModel") + { + EnableEvents(eventSource, EventLevel.Verbose); + } + } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + // Work around https://github.com/dotnet/corefx/issues/42600 + if (eventData.EventId == -1) + { + return; + } + + if (!_disposed) + { + // Make sure we can format the event + Format(eventData); + _events.Enqueue(eventData); + } + } + + public EventWrittenEventArgs GetAndValidateSingleEvent(int eventId, string expectedEventName, EventLevel expectedEventLevel, string expectedEventSourceName) + { + EventWrittenEventArgs args = SingleEventById(eventId); + Assert.AreEqual(expectedEventName, args.EventName); + Assert.AreEqual(expectedEventLevel, args.Level); + Assert.AreEqual(expectedEventSourceName, args.EventSource.Name); + string requestId = args.GetProperty("requestId"); + Assert.That(string.IsNullOrEmpty(requestId), Is.False); + return args; + } + + public EventWrittenEventArgs SingleEventById(int id, Func? filter = default) + { + return EventsById(id).Single(filter ?? (_ => true)); + } + + public void ValidateNumberOfEventsById(int eventId, int expectedNumEvents) + { + Assert.AreEqual(expectedNumEvents, EventsById(eventId).Count()); + } + + public IEnumerable EventsById(int id) + { + return _events.Where(e => e.EventId == id); + } + + public void AssertNoContentLogged() + { + CollectionAssert.IsEmpty(EventsById(2)); // RequestContentEvent + CollectionAssert.IsEmpty(EventsById(17)); // RequestContentTextEvent + + CollectionAssert.IsEmpty(EventsById(6)); // ResponseContentEvent + CollectionAssert.IsEmpty(EventsById(13)); // ResponseContentTextEvent + CollectionAssert.IsEmpty(EventsById(11)); // ResponseContentBlockEvent + CollectionAssert.IsEmpty(EventsById(15)); // ResponseContentTextBlockEvent + + CollectionAssert.IsEmpty(EventsById(9)); // ErrorResponseContentEvent + CollectionAssert.IsEmpty(EventsById(14)); // ErrorResponseContentTextEvent + CollectionAssert.IsEmpty(EventsById(12)); // ErrorResponseContentBlockEvent + CollectionAssert.IsEmpty(EventsById(16)); // ErrorResponseContentTextBlockEvent + } + + public override void Dispose() + { + _disposed = true; + base.Dispose(); + } + + #region Helpers + + private static string Format(EventWrittenEventArgs eventData) + { + var payloadArray = eventData.Payload?.ToArray() ?? Array.Empty(); + + ProcessPayloadArray(payloadArray); + + if (eventData.Message != null) + { + try + { + return string.Format(CultureInfo.InvariantCulture, eventData.Message, payloadArray); + } + catch (FormatException) + { + } + } + + var stringBuilder = new StringBuilder(); + stringBuilder.Append(eventData.EventName); + + if (!string.IsNullOrWhiteSpace(eventData.Message)) + { + stringBuilder.AppendLine(); + stringBuilder.Append(nameof(eventData.Message)).Append(" = ").Append(eventData.Message); + } + + if (eventData.PayloadNames != null) + { + for (int i = 0; i < eventData.PayloadNames.Count; i++) + { + stringBuilder.AppendLine(); + stringBuilder.Append(eventData.PayloadNames[i]).Append(" = ").Append(payloadArray[i]); + } + } + + return stringBuilder.ToString(); + } + + private static void ProcessPayloadArray(object?[] payloadArray) + { + for (int i = 0; i < payloadArray.Length; i++) + { + payloadArray[i] = FormatValue(payloadArray[i]); + } + } + + private static object? FormatValue(object? o) + { + if (o is byte[] bytes) + { + var stringBuilder = new StringBuilder(); + foreach (byte b in bytes) + { + stringBuilder.AppendFormat(CultureInfo.InvariantCulture, "{0:X2}", b); + } + + return stringBuilder.ToString(); + } + + return o; + } + + #endregion + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestLogger.cs b/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestLogger.cs new file mode 100644 index 000000000000..627d9a3d4f1e --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestLogger.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel.Tests.TestFramework; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.Logging; +using NUnit.Framework; + +namespace ClientModel.Tests; + +public class TestLogger : ILogger +{ + private LogLevel _logLevel; + private readonly ConcurrentQueue _logs = new(); + + public TestLogger(LogLevel logLevel, string name) + { + _logLevel = logLevel; + Name = name; + } + + public IEnumerable Logs => _logs; + + public string Name { get; set; } + + public IDisposable BeginScope(TState state) + { + throw new NotImplementedException(); + } + + public bool IsEnabled(LogLevel logLevel) + { + return logLevel >= _logLevel; + } + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (IsEnabled(logLevel)) + { + IReadOnlyList> arguments = state as IReadOnlyList> ?? new List>(); + var loggerEvent = new LoggerEvent(logLevel, formatter(state, exception), exception, eventId, arguments); + _logs.Enqueue(loggerEvent); + } + } + + public LoggerEvent GetAndValidateSingleEvent(int eventId, string expectedEventName, LogLevel expectedEventLevel) + { + LoggerEvent log = SingleEventById(eventId); + Assert.AreEqual(expectedEventName, log.EventId.Name); + Assert.AreEqual(expectedEventLevel, log.LogLevel); + string requestId = log.GetValueFromArguments("requestId"); + Assert.That(string.IsNullOrEmpty(requestId), Is.False); + return log; + } + + public LoggerEvent SingleEventById(int eventId, Func? filter = default) + { + return EventsById(eventId).Single(filter ?? (_ => true)); + } + + public void ValidateNumberOfEventsById(int eventId, int expectedNumEvents) + { + Assert.AreEqual(expectedNumEvents, EventsById(eventId).Count()); + } + + public IEnumerable EventsById(int eventId) + { + return _logs.Where(e => e.EventId.Id == eventId); + } + + public void AssertNoContentLogged() + { + CollectionAssert.IsEmpty(EventsById(2)); // RequestContentEvent + CollectionAssert.IsEmpty(EventsById(17)); // RequestContentTextEvent + + CollectionAssert.IsEmpty(EventsById(6)); // ResponseContentEvent + CollectionAssert.IsEmpty(EventsById(13)); // ResponseContentTextEvent + CollectionAssert.IsEmpty(EventsById(11)); // ResponseContentBlockEvent + CollectionAssert.IsEmpty(EventsById(15)); // ResponseContentTextBlockEvent + + CollectionAssert.IsEmpty(EventsById(9)); // ErrorResponseContentEvent + CollectionAssert.IsEmpty(EventsById(14)); // ErrorResponseContentTextEvent + CollectionAssert.IsEmpty(EventsById(12)); // ErrorResponseContentBlockEvent + CollectionAssert.IsEmpty(EventsById(16)); // ErrorResponseContentTextBlockEvent + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestLoggingFactory.cs b/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestLoggingFactory.cs new file mode 100644 index 000000000000..6b6b5f1c5cfb --- /dev/null +++ b/sdk/core/System.ClientModel/tests/TestFramework/Logging/TestLoggingFactory.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace ClientModel.Tests; + +public class TestLoggingFactory : ILoggerFactory +{ + private readonly ConcurrentDictionary _loggers; + + public TestLoggingFactory(LogLevel level) + { + _loggers = new(); + LogLevel = level; + } + + public LogLevel LogLevel { get; } + + public void AddProvider(ILoggerProvider provider) + { + throw new NotImplementedException(); + } + + public ILogger CreateLogger(string categoryName) + { + return _loggers.GetOrAdd(categoryName, name => new TestLogger(LogLevel, name)); + } + + public TestLogger GetLogger(string categoryName) + { + return _loggers[categoryName]; + } + + public void Dispose() + { + _loggers.Clear(); + } +} diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineRequest.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineRequest.cs index bfafc9750091..7a2186539d94 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineRequest.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineRequest.cs @@ -20,6 +20,7 @@ public MockPipelineRequest() { _headers = new MockRequestHeaders(); _method = "GET"; + _uri = new Uri("https://www.example.com"); } protected override BinaryContent? ContentCore diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs index d04486feb670..0b7576d994a8 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineResponse.cs @@ -21,11 +21,11 @@ public class MockPipelineResponse : PipelineResponse private bool _disposed; - public MockPipelineResponse(int status = 0, string reasonPhrase = "") + public MockPipelineResponse(int status = 0, string reasonPhrase = "", MockResponseHeaders? mockHeaders = default) { _status = status; _reasonPhrase = reasonPhrase; - _headers = new MockResponseHeaders(); + _headers = mockHeaders ?? new MockResponseHeaders(); } public override int Status => _status; @@ -127,6 +127,7 @@ public override BinaryData BufferContent(CancellationToken cancellationToken = d // Less efficient FromStream method called here because it is a mock. // For intended production implementation, see HttpClientTransportResponse. _bufferedContent = BinaryData.FromStream(bufferStream); + _contentStream.Seek(0, SeekOrigin.Begin); return _bufferedContent; } @@ -158,6 +159,7 @@ public override async ValueTask BufferContentAsync(CancellationToken // Less efficient FromStream method called here because it is a mock. // For intended production implementation, see HttpClientTransportResponse. _bufferedContent = BinaryData.FromStream(bufferStream); + _contentStream.Seek(0, SeekOrigin.Begin); return _bufferedContent; } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs index 046edc105811..3a11a6cf2202 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockPipelineTransport.cs @@ -8,79 +8,69 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; namespace ClientModel.Tests.Mocks; public class MockPipelineTransport : PipelineTransport { - private readonly Func _responseFactory; - private int _retryCount; + private readonly Func _responseFactory; + private readonly bool _addDelay; public string Id { get; } - public Action? OnSendingRequest { get; set; } - public Action? OnReceivedResponse { get; set; } + public Action? OnSendingRequest { get; set; } + public Action? OnReceivedResponse { get; set; } public MockPipelineTransport(string id, params int[] codes) - : this(id, i => codes[i]) { + Id = id; + var requestIndex = 0; + _responseFactory = _ => { return new MockPipelineResponse(codes[requestIndex++]); }; } - public MockPipelineTransport(string id, Func responseFactory) + public MockPipelineTransport(string id, Func responseFactory, bool enableLogging = false, ILoggerFactory? loggerFactory = null, bool addDelay = false) : base(enableLogging, loggerFactory) { Id = id; _responseFactory = responseFactory; + _addDelay = addDelay; } protected override PipelineMessage CreateMessageCore() { - return new RetriableTransportMessage(); + return new MockPipelineMessage(); } protected override void ProcessCore(PipelineMessage message) { - try - { - Stamp(message, "Transport"); + Stamp(message, "Transport"); - OnSendingRequest?.Invoke(_retryCount); + OnSendingRequest?.Invoke((MockPipelineMessage)message); - if (message is RetriableTransportMessage transportMessage) - { - int status = _responseFactory(_retryCount); - transportMessage.SetResponse(status); - } + ((MockPipelineMessage)message).SetResponse(_responseFactory(message)); - OnReceivedResponse?.Invoke(_retryCount); - } - finally + if (_addDelay) { - _retryCount++; + Task.Delay(TimeSpan.FromSeconds(4)).Wait(); } + + OnReceivedResponse?.Invoke((MockPipelineMessage)message); } - protected override ValueTask ProcessCoreAsync(PipelineMessage message) + protected override async ValueTask ProcessCoreAsync(PipelineMessage message) { - try - { - Stamp(message, "Transport"); + Stamp(message, "Transport"); - OnSendingRequest?.Invoke(_retryCount); + OnSendingRequest?.Invoke((MockPipelineMessage)message); - if (message is RetriableTransportMessage transportMessage) - { - int status = _responseFactory(_retryCount); - transportMessage.SetResponse(status); - } + ((MockPipelineMessage)message).SetResponse(_responseFactory(message)); - OnReceivedResponse?.Invoke(_retryCount); - } - finally + if (_addDelay) { - _retryCount++; + await Task.Delay(TimeSpan.FromSeconds(4)); } - return new ValueTask(); + OnReceivedResponse?.Invoke((MockPipelineMessage)message); } private void Stamp(PipelineMessage message, string prefix) @@ -100,90 +90,4 @@ private void Stamp(PipelineMessage message, string prefix) values.Add($"{prefix}:{Id}"); } - - private class RetriableTransportMessage : PipelineMessage - { - public RetriableTransportMessage() : this(new TransportRequest()) - { - } - - protected internal RetriableTransportMessage(PipelineRequest request) : base(request) - { - } - - public void SetResponse(int status) - { - Response = new RetriableTransportResponse(status); - } - } - - private class TransportRequest : PipelineRequest - { - private Uri? _uri; - private readonly PipelineRequestHeaders _headers; - - public TransportRequest() - { - _headers = new MockRequestHeaders(); - _uri = new Uri("https://www.example.com"); - } - - public override void Dispose() { } - - protected override BinaryContent? ContentCore - { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } - - protected override PipelineRequestHeaders HeadersCore - => _headers; - - protected override string MethodCore - { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } - - protected override Uri? UriCore - { - get => _uri; - set => _uri = value; - } - } - - private class RetriableTransportResponse : PipelineResponse - { - public RetriableTransportResponse(int status) - { - Status = status; - } - - public override int Status { get; } - - public override string ReasonPhrase => throw new NotImplementedException(); - - public override Stream? ContentStream - { - get => null; - set => throw new NotImplementedException(); - } - - public override BinaryData Content => throw new NotImplementedException(); - - protected override PipelineResponseHeaders HeadersCore - => new MockResponseHeaders(); - - public override void Dispose() { } - - public override BinaryData BufferContent(CancellationToken cancellationToken = default) - { - throw new NotImplementedException(); - } - - public override ValueTask BufferContentAsync(CancellationToken cancellationToken = default) - { - throw new NotImplementedException(); - } - } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockRequestHeaders.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockRequestHeaders.cs index 045e3a27e4e1..acdbcf7ad2ef 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockRequestHeaders.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockRequestHeaders.cs @@ -30,7 +30,8 @@ public override void Add(string name, string value) public override IEnumerator> GetEnumerator() { - throw new NotImplementedException(); + IEnumerator> enumerator = _headers.GetEnumerator(); + return enumerator; } public override bool Remove(string name) @@ -50,6 +51,15 @@ public override bool TryGetValue(string name, out string? value) public override bool TryGetValues(string name, out IEnumerable? values) { - throw new NotImplementedException(); + bool hasValue = _headers.TryGetValue(name, out string? dictionaryValue); + + if (!hasValue || string.IsNullOrEmpty(dictionaryValue)) + { + values = null; + return false; + } + + values = dictionaryValue.Split(','); + return true; } } diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs index f297bcd6334d..1c059863cfdb 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/MockResponseHeaders.cs @@ -11,9 +11,9 @@ public class MockResponseHeaders : PipelineResponseHeaders { private readonly Dictionary _headers; - public MockResponseHeaders() + public MockResponseHeaders(Dictionary? headers = default) { - _headers = new Dictionary(); + _headers = headers ?? new Dictionary(); } public void SetHeader(string name, string value) @@ -21,7 +21,8 @@ public void SetHeader(string name, string value) public override IEnumerator> GetEnumerator() { - throw new NotImplementedException(); + IEnumerator> enumerator = _headers.GetEnumerator(); + return enumerator; } public override bool TryGetValue(string name, out string? value) diff --git a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs index be6496cea946..d60940373472 100644 --- a/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs +++ b/sdk/core/System.ClientModel/tests/TestFramework/Mocks/ObservableTransport.cs @@ -92,42 +92,30 @@ public override void Dispose() protected override BinaryContent? ContentCore { - get => throw new NotImplementedException(); + get => null; set => throw new NotImplementedException(); } protected override PipelineRequestHeaders HeadersCore - => throw new NotImplementedException(); + => new MockRequestHeaders(); - protected override string MethodCore - { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } + protected override string MethodCore { get; set; } = "GET"; - protected override Uri? UriCore - { - get => throw new NotImplementedException(); - set => throw new NotImplementedException(); - } + protected override Uri? UriCore { get; set; } = new Uri("http://example.com"); // For the logging policy } private class TransportResponse : PipelineResponse { public override int Status => 0; - public override string ReasonPhrase => throw new NotImplementedException(); + public override string ReasonPhrase { get; } = string.Empty; - public override Stream? ContentStream - { - get => null; - set => throw new NotImplementedException(); - } + public override Stream? ContentStream { get; set; } - public override BinaryData Content => throw new NotImplementedException(); + public override BinaryData Content { get; } = new BinaryData(new byte[0]); protected override PipelineResponseHeaders HeadersCore - => throw new NotImplementedException(); + => new MockResponseHeaders(); public override void Dispose() { diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/ChangeTrackingStringListTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/ChangeTrackingStringListTests.cs new file mode 100644 index 000000000000..0e628d249059 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/ChangeTrackingStringListTests.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal; + +internal class ChangeTrackingStringListTests +{ + [Test] + public void CanDetectAddChange() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + list.Add("b"); + + Assert.IsTrue(list.HasChanged); + } + + [Test] + public void CanDetectSetChange() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + list[0] = "b"; + + Assert.IsTrue(list.HasChanged); + } + + [Test] + public void CanDetectClearChange() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + list.Clear(); + + Assert.IsTrue(list.HasChanged); + } + + [Test] + public void ClearNotAChangeForEmptyList() + { + ChangeTrackingStringList list = new(); + + Assert.IsFalse(list.HasChanged); + + list.Clear(); + + Assert.IsFalse(list.HasChanged); + } + + [Test] + public void CanDetectInsertChange() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + list.Insert(0, "b"); + + Assert.IsTrue(list.HasChanged); + } + + [Test] + public void CanDetectRemoveChange() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + list.Remove("a"); + + Assert.IsTrue(list.HasChanged); + } + + [Test] + public void RemoveNotAChangeIfNotRemoved() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + bool removed = list.Remove("b"); + + Assert.IsFalse(removed); + Assert.IsFalse(list.HasChanged); + } + + [Test] + public void CanDetectRemoveAtChange() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + list.RemoveAt(0); + + Assert.IsTrue(list.HasChanged); + } + + [Test] + public void RemoveAtNotAChangeIfNotRemoved() + { + ChangeTrackingStringList list = new(["a"]); + + Assert.IsFalse(list.HasChanged); + + Assert.Throws(() => list.RemoveAt(1)); + + Assert.IsFalse(list.HasChanged); + } + + [Test] + public void CanCreateListFromCollection() + { + List originalList = ["a", "b", "c"]; + ChangeTrackingStringList changeTrackingList = new(originalList); + + Assert.AreEqual(originalList.Count, changeTrackingList.Count); + Assert.IsTrue(changeTrackingList.Contains(originalList[0])); + Assert.IsTrue(changeTrackingList.Contains(originalList[1])); + Assert.IsTrue(changeTrackingList.Contains(originalList[2])); + } + + [Test] + public void CanCreateListFromCollectionAndTrackChanges() + { + List originalList = ["a", "b", "c"]; + ChangeTrackingStringList changeTrackingList = new(originalList); + + changeTrackingList.Add("d"); + + Assert.IsTrue(changeTrackingList.HasChanged); + } + + [Test] + public void CannotModifyFrozenList() + { + ChangeTrackingStringList list = ["a"]; + + list.Add("b"); + list.Add("c"); + + list.Freeze(); + + Assert.Throws(() => list.Add("d")); + Assert.Throws(() => list[0] = "d"); + Assert.Throws(() => list.Clear()); + Assert.Throws(() => list.Insert(0, "d")); + Assert.Throws(() => list.Remove("a")); + Assert.Throws(() => list.RemoveAt(0)); + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/ClientLoggingOptionsTestsInternal.cs b/sdk/core/System.ClientModel/tests/internal/Internal/ClientLoggingOptionsTestsInternal.cs new file mode 100644 index 000000000000..c464df80abdf --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/ClientLoggingOptionsTestsInternal.cs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using Microsoft.Extensions.Logging.Abstractions; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal +{ + internal class ClientLoggingOptionsTestsInternal + { + private static HashSet s_expectedDefaultAllowedHeaderNames { get; } = [ + "traceparent", + "Accept", + "Cache-Control", + "Connection", + "Content-Length", + "Content-Type", + "Date", + "ETag", + "Expires", + "If-Match", + "If-Modified-Since", + "If-None-Match", + "If-Unmodified-Since", + "Last-Modified", + "Pragma", + "Retry-After", + "Server", + "Transfer-Encoding", + "User-Agent", + "WWW-Authenticate" ]; + private static HashSet s_expectedDefaultAllowedQueryParameters { get; } = ["api-version"]; + + [Test] + public void ValidOptionsAreConsideredValid() + { + List validLoggingOptions = [ + new ClientLoggingOptions(), + new ClientLoggingOptions { EnableLogging = true }, + new ClientLoggingOptions { EnableLogging = false }, + new ClientLoggingOptions { EnableMessageLogging = true }, + new ClientLoggingOptions { EnableMessageLogging = false }, + new ClientLoggingOptions { EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableMessageContentLogging = false }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageLogging = true }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageLogging = false }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageLogging = false }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageContentLogging = false }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageContentLogging = false }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageLogging = true, EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageLogging = true, EnableMessageContentLogging = false }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageLogging = false, EnableMessageContentLogging = false }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageLogging = false, EnableMessageContentLogging = false }, + ]; + + foreach (ClientLoggingOptions options in validLoggingOptions) + { + Assert.DoesNotThrow(() => options.ValidateOptions()); + options.LoggerFactory = NullLoggerFactory.Instance; + options.MessageContentSizeLimit = 15; + Assert.DoesNotThrow(() => options.ValidateOptions()); + } + } + + [Test] + public void InValidOptionsAreConsideredInValid() + { + List validLoggingOptions = [ + new ClientLoggingOptions { EnableLogging = false, EnableMessageLogging = true }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableMessageLogging = false, EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageLogging = true, EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageLogging = true, EnableMessageContentLogging = false }, + new ClientLoggingOptions { EnableLogging = false, EnableMessageLogging = false, EnableMessageContentLogging = true }, + new ClientLoggingOptions { EnableLogging = true, EnableMessageLogging = false, EnableMessageContentLogging = true }, + ]; + + foreach (var options in validLoggingOptions) + { + Assert.Throws(() => options.ValidateOptions()); + options.LoggerFactory = NullLoggerFactory.Instance; + options.MessageContentSizeLimit = 15; + Assert.Throws(() => options.ValidateOptions()); + } + } + + [Test] + public void CanGetDefaultSanitizer() + { + ClientLoggingOptions options = new(); + PipelineMessageSanitizer sanitizer = options.GetPipelineMessageSanitizer(); + + Assert.AreEqual(s_expectedDefaultAllowedQueryParameters, sanitizer._allowedQueryParameters); + Assert.AreEqual(s_expectedDefaultAllowedHeaderNames, sanitizer._allowedHeaders); + } + + [Test] + public void CanGetCustomizedSanitizer() + { + ClientLoggingOptions options = new(); + options.AllowedHeaderNames.Add("Custom-Header"); + options.AllowedQueryParameters.Add("custom-query"); + PipelineMessageSanitizer sanitizer = options.GetPipelineMessageSanitizer(); + + HashSet customAllowedHeaders = new(s_expectedDefaultAllowedHeaderNames) + { + "Custom-Header" + }; + + HashSet customQueryParameters = new(s_expectedDefaultAllowedQueryParameters) + { + "custom-query" + }; + + Assert.AreEqual(customQueryParameters, sanitizer._allowedQueryParameters); + Assert.AreEqual(customAllowedHeaders, sanitizer._allowedHeaders); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/ContentTypeUtilitiesTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/ContentTypeUtilitiesTests.cs new file mode 100644 index 000000000000..a8e1e11912b9 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/ContentTypeUtilitiesTests.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Text; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal +{ + internal class ContentTypeUtilitiesTests + { + [Theory] + [TestCase(null, false, null)] + [TestCase("text/json", true, "Unicode (UTF-8)")] + [TestCase("text/xml", true, "Unicode (UTF-8)")] + [TestCase("application/json", true, "Unicode (UTF-8)")] + [TestCase("application/xml", true, "Unicode (UTF-8)")] + [TestCase("something/else+json", true, "Unicode (UTF-8)")] + [TestCase("something/else+xml", true, "Unicode (UTF-8)")] + [TestCase("random/thing; charset=utf-8", true, "Unicode (UTF-8)")] + [TestCase("application/json; odata.metadata=minimal", true, "Unicode (UTF-8)")] + [TestCase("application/json; odata.metadata=full", true, "Unicode (UTF-8)")] + [TestCase("application/json; odata.metadata=none", true, "Unicode (UTF-8)")] + [TestCase("application/x-www-form-urlencoded", true, "Unicode (UTF-8)")] + [TestCase("application/x-www-form-urlencoded; charset=utf-8", true, "Unicode (UTF-8)")] + + // No other explicit encoding besides "utf-8" is supported, so falls through to defaulting to "utf-8" based on Content-Type. + [TestCase("application/x-www-form-urlencoded; charset=us-ascii", true, "Unicode (UTF-8)")] + + public void DetectsTextContentTypes(string contentType, bool isText, string expectedEncoding) + { + Assert.AreEqual(isText, ContentTypeUtilities.TryGetTextEncoding(contentType, out Encoding? encoding)); + Assert.AreEqual(encoding?.EncodingName, expectedEncoding); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageHeadersLogValueTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageHeadersLogValueTests.cs new file mode 100644 index 000000000000..e6ec2ab43ffc --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageHeadersLogValueTests.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.ClientModel.Internal; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using ClientModel.Tests.Mocks; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal; + +public class PipelineMessageHeadersLogValueTests +{ + [Test] + public void PipelineMessageHeadersLogValueToStringHidesOnlySensitiveHeaders() + { + MockRequestHeaders requestHeaders = new() + { + { "Sensitive-Header", "SensitiveValue" }, + { "NonSensitive-Header", "NonSensitiveValue" }, + { "Content-Length", "6" } + }; + Dictionary headers = new() + { + { "Sensitive-Header", "SensitiveValue" }, + { "NonSensitive-Header", "NonSensitiveValue" }, + { "Content-Length", "6" } + }; + MockResponseHeaders responseHeaders = new(headers); + + PipelineMessageSanitizer sanitizer = new([], ["NonSensitive-Header"]); + + PipelineMessageHeadersLogValue requestLogValue = new(requestHeaders, sanitizer); + PipelineMessageHeadersLogValue responseLogValue = new(responseHeaders, sanitizer); + + string loggedRequestValue = requestLogValue.ToString(); + string loggedResponseValue = responseLogValue.ToString(); + + Assert.That(loggedRequestValue, Is.Not.Null); + Assert.That(loggedResponseValue, Is.Not.Null); + + Assert.AreEqual(loggedRequestValue, "Sensitive-Header:REDACTED\r\nNonSensitive-Header:NonSensitiveValue\r\nContent-Length:REDACTED\r\n"); + Assert.AreEqual(loggedResponseValue, "Sensitive-Header:REDACTED\r\nNonSensitive-Header:NonSensitiveValue\r\nContent-Length:REDACTED\r\n"); + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageLoggerTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageLoggerTests.cs new file mode 100644 index 000000000000..4d748dfe9812 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageLoggerTests.cs @@ -0,0 +1,1075 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using System.ClientModel.Tests.TestFramework; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests; +using ClientModel.Tests.Mocks; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json.Bson; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal; + +// Avoid running these tests in parallel with anything else that's sharing the event source +[NonParallelizable] +public class PipelineMessageLoggerTests : SyncAsyncPolicyTestBase +{ + private const int RequestEvent = 1; + private const int RequestContentEvent = 2; + private const int ResponseEvent = 5; + private const int ResponseContentEvent = 6; + private const int ErrorResponseEvent = 8; + private const int ErrorResponseContentEvent = 9; + private const int ResponseContentBlockEvent = 11; + private const int ErrorResponseContentBlockEvent = 12; + private const int ResponseContentTextEvent = 13; + private const int ErrorResponseContentTextEvent = 14; + private const int ResponseContentTextBlockEvent = 15; + private const int ErrorResponseContentTextBlockEvent = 16; + private const int RequestContentTextEvent = 17; + + private const string LoggingPolicyCategoryName = "System.ClientModel.Primitives.MessageLoggingPolicy"; + private const string SystemClientModelEventSourceName = "System.ClientModel"; + private readonly MockResponseHeaders _defaultHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + private readonly MockResponseHeaders _defaultTextHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Content-Type", "text/plain" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + + public PipelineMessageLoggerTests(bool isAsync) : base(isAsync) + { + } + + #region Unit tests + + [Test] + public void LogsAreLoggedToILoggerAndNotEventSourceWhenILoggerIsProvided() + { + using TestClientEventListener listener = new(); + using TestLoggingFactory factory = new(LogLevel.Debug); + + PipelineMessageLogger messageLogger = new(new PipelineMessageSanitizer([], []), factory); + + MockPipelineRequest request = new() + { + Uri = new Uri("http://example.com/") + }; + MockPipelineResponse response = new(500); + + messageLogger.LogRequest("requestId", request, "assembly"); + messageLogger.LogRequestContent("requestId", [1,2,3], null); + messageLogger.LogRequestContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponse("requestId", response, 1); + messageLogger.LogResponseContent("requestId", [1,2,3], null); + messageLogger.LogResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponseContentBlock("requestId", 1, [1,2,3], null); + messageLogger.LogResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponse("requestId", response, 1); + messageLogger.LogErrorResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogErrorResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogErrorResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + logger.SingleEventById(RequestEvent); + logger.SingleEventById(RequestContentEvent); + logger.SingleEventById(RequestContentTextEvent); + logger.SingleEventById(ResponseEvent); + logger.SingleEventById(ResponseContentEvent); + logger.SingleEventById(ResponseContentTextEvent); + logger.SingleEventById(ResponseContentBlockEvent); + logger.SingleEventById(ResponseContentTextBlockEvent); + logger.SingleEventById(ErrorResponseEvent); + logger.SingleEventById(ErrorResponseContentEvent); + logger.SingleEventById(ErrorResponseContentBlockEvent); + logger.SingleEventById(ErrorResponseContentTextBlockEvent); + + CollectionAssert.IsEmpty(listener.EventData); // Nothing should log to Event Source + } + + [Test] + public void LogsAreNotWrittenToEventSourceWhenILoggerIsProvidedAndLogLevelIsWarning() + { + using TestClientEventListener listener = new(); + using TestLoggingFactory factory = new(LogLevel.Warning); + + PipelineMessageLogger messageLogger = new(new PipelineMessageSanitizer([], []), factory); + + MockPipelineRequest request = new() + { + Uri = new Uri("http://example.com/") + }; + MockPipelineResponse response = new(500); + + messageLogger.LogRequest("requestId", request, "assembly"); + messageLogger.LogRequestContent("requestId", [1, 2, 3], null); + messageLogger.LogRequestContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponse("requestId", response, 1); + messageLogger.LogResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponse("requestId", response, 1); + messageLogger.LogErrorResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogErrorResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogErrorResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + + CollectionAssert.IsEmpty(listener.EventData); // Nothing should log to Event Source + } + + [Test] + public void IsEnabledILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + PipelineMessageLogger messageLogger = new(new PipelineMessageSanitizer([], []), factory); + + Assert.IsTrue(messageLogger.IsEnabled(LogLevel.Debug, EventLevel.Verbose)); + Assert.IsTrue(messageLogger.IsEnabled(LogLevel.Critical, EventLevel.Verbose)); + Assert.IsFalse(messageLogger.IsEnabled(LogLevel.Trace, EventLevel.Warning)); + } + + [Test] + public void EventsAreNotLoggedIfDisabledEventSource() + { + using TestEventListenerWarning listener = new(); + PipelineMessageLogger messageLogger = new(new PipelineMessageSanitizer([], []), null); + MockPipelineRequest request = new() + { + Uri = new Uri("http://example.com/") + }; + MockPipelineResponse response = new(500); + + messageLogger.LogRequest("requestId", request, "assembly"); + messageLogger.LogRequestContent("requestId", [1, 2, 3], null); + messageLogger.LogRequestContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponse("requestId", response, 1); + messageLogger.LogResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponse("requestId", response, 1); + messageLogger.LogErrorResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogErrorResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogErrorResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + + listener.SingleEventById(ErrorResponseEvent); + Assert.AreEqual(1, listener.EventData.Count()); + } + + [Test] + public void EventsAreNotLoggedIfDisabledILogger() + { + using TestLoggingFactory factory = new(LogLevel.Warning); + PipelineMessageLogger messageLogger = new(new PipelineMessageSanitizer([], []), factory); + MockPipelineRequest request = new() + { + Uri = new Uri("http://example.com/") + }; + MockPipelineResponse response = new(500); + + messageLogger.LogRequest("requestId", request, "assembly"); + messageLogger.LogRequestContent("requestId", [1, 2, 3], null); + messageLogger.LogRequestContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponse("requestId", response, 1); + messageLogger.LogResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponse("requestId", response, 1); + messageLogger.LogErrorResponseContent("requestId", [1, 2, 3], null); + messageLogger.LogErrorResponseContent("requestId", "Hello"u8.ToArray(), Encoding.UTF8); // text + messageLogger.LogErrorResponseContentBlock("requestId", 1, [1, 2, 3], null); + messageLogger.LogErrorResponseContentBlock("requestId", 1, "Hello"u8.ToArray(), Encoding.UTF8); // text + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + logger.SingleEventById(ErrorResponseEvent); + Assert.AreEqual(1, logger.Logs.Count()); + } + + #endregion + + #region Integration tests + + [Test] + public async Task HeadersAndQueryParametersAreSanitizedInRequestAndResponseEventsEventSource() // Request event and response event sanitize headers + { + using TestClientEventListener listener = new(); + + var mockHeaders = new MockResponseHeaders(new Dictionary { { "Custom-Response-Header", "Improved value" }, { "Secret-Response-Header", "Very secret" } }); + var response = new MockPipelineResponse(200, mockHeaders: mockHeaders); + response.SetContent([6, 7, 8, 9, 0]); + + Dictionary requestHeaders = new() + { + { "Secret-Custom-Header", "secret-value" }, + { "Content-Type", "text/json" } + }; + + Uri requestUri = new("https://contoso.a.io?api-version=5&secret=123"); + + await CreatePipelineAndSendRequest(response, requestContentBytes: [1, 2, 3, 4, 5], requestHeaders: requestHeaders, requestUri: requestUri); + + // Assert that headers on the request are sanitized + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.RequestEvent, "Request", EventLevel.Informational, SystemClientModelEventSourceName); + string headers = log.GetProperty("headers"); + StringAssert.Contains($"Date:08/16/2024{Environment.NewLine}", headers); + StringAssert.Contains($"Custom-Header:custom-header-value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Custom-Header:REDACTED{Environment.NewLine}", headers); + StringAssert.DoesNotContain("secret-value", headers); + + // Assert that headers on the response are sanitized + + log = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseEvent, "Response", EventLevel.Informational, SystemClientModelEventSourceName); + headers = log.GetProperty("headers"); + StringAssert.Contains($"Custom-Response-Header:Improved value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Response-Header:REDACTED{Environment.NewLine}", headers); + StringAssert.DoesNotContain("Very secret", headers); + } + + [Test] + public async Task HeadersAndQueryParametersAreSanitizedInRequestAndResponseEventsILogger() // Request event and response event sanitize headers + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + var mockHeaders = new MockResponseHeaders(new Dictionary { { "Custom-Response-Header", "Improved value" }, { "Secret-Response-Header", "Very secret" } }); + var response = new MockPipelineResponse(200, mockHeaders: mockHeaders); + response.SetContent([6, 7, 8, 9, 0]); + + Dictionary requestHeaders = new() + { + { "Secret-Custom-Header", "secret-value" }, + { "Content-Type", "text/json" } + }; + + Uri requestUri = new("https://contoso.a.io?api-version=5&secret=123"); + + await CreatePipelineAndSendRequest(response, requestContentBytes: [1, 2, 3, 4, 5], requestHeaders: requestHeaders, requestUri: requestUri, loggingOptions: loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + // Assert that headers on the request are sanitized + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.RequestEvent, "Request", LogLevel.Information); + string headers = (log.GetValueFromArguments("headers")).ToString(); + StringAssert.Contains($"Date:08/16/2024{Environment.NewLine}", headers); + StringAssert.Contains($"Custom-Header:custom-header-value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Custom-Header:REDACTED{Environment.NewLine}", headers); + StringAssert.DoesNotContain("secret-value", headers); + + // Assert that headers on the response are sanitized + + log = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseEvent, "Response", LogLevel.Information); + headers = (log.GetValueFromArguments("headers")).ToString(); + StringAssert.Contains($"Custom-Response-Header:Improved value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Response-Header:REDACTED{Environment.NewLine}", headers); + StringAssert.DoesNotContain("Very secret", headers); + } + + [Test] + public async Task HeadersAndQueryParametersAreSanitizedInErrorResponseEventEventSource() // Error response event sanitizes headers + { + using TestClientEventListener listener = new(); + + var mockHeaders = new MockResponseHeaders(new Dictionary { { "Custom-Response-Header", "Improved value" }, { "Secret-Response-Header", "Very secret" } }); + var response = new MockPipelineResponse(400, mockHeaders: mockHeaders); + response.SetContent([6, 7, 8, 9, 0]); + + Dictionary requestHeaders = new() + { + { "Secret-Custom-Header", "secret-value" }, + { "Content-Type", "text/json" } + }; + + Uri requestUri = new("https://contoso.a.io?api-version=5&secret=123"); + + await CreatePipelineAndSendRequest(response, requestContentBytes: [1, 2, 3, 4, 5], requestHeaders: requestHeaders, requestUri: requestUri); + + // Assert that headers on the response are sanitized + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseEvent, "ErrorResponse", EventLevel.Warning, SystemClientModelEventSourceName); + string headers = log.GetProperty("headers"); + StringAssert.Contains($"Custom-Response-Header:Improved value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Response-Header:REDACTED{Environment.NewLine}", headers); + StringAssert.DoesNotContain("Very Secret", headers); + } + + [Test] + public async Task HeadersAndQueryParametersAreSanitizedInErrorResponseEventILogger() // Error response event sanitizes headers + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + var mockHeaders = new MockResponseHeaders(new Dictionary { { "Custom-Response-Header", "Improved value" }, { "Secret-Response-Header", "Very secret" } }); + var response = new MockPipelineResponse(400, mockHeaders: mockHeaders); + response.SetContent([6, 7, 8, 9, 0]); + + Dictionary requestHeaders = new() + { + { "Secret-Custom-Header", "secret-value" }, + { "Content-Type", "text/json" } + }; + + Uri requestUri = new("https://contoso.a.io?api-version=5&secret=123"); + + await CreatePipelineAndSendRequest(response, requestContentBytes: [1, 2, 3, 4, 5], requestHeaders: requestHeaders, requestUri: requestUri, loggingOptions: loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + // Assert that headers on the response are sanitized + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseEvent, "ErrorResponse", LogLevel.Warning); + string headers = (log.GetValueFromArguments("headers")).ToString(); + StringAssert.Contains($"Custom-Response-Header:Improved value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Response-Header:REDACTED{Environment.NewLine}", headers); + StringAssert.DoesNotContain("Very Secret", headers); + } + + [Test] + public async Task HeadersAndQueryParametersAreNotSanitizedWhenStarsEventSource() + { + using TestClientEventListener listener = new(); + + var mockHeaders = new MockResponseHeaders(new Dictionary { { "Custom-Response-Header", "Improved value" }, { "Secret-Response-Header", "Very secret" } }); + var response = new MockPipelineResponse(200, mockHeaders: mockHeaders); + response.SetContent([6, 7, 8, 9, 0]); + + ClientLoggingOptions loggingOptions = new(); + loggingOptions.AllowedQueryParameters.Add("*"); + loggingOptions.AllowedHeaderNames.Add("*"); + + Uri requestUri = new("https://contoso.a.io?api-version=5&secret=123"); + + Dictionary requestHeaders = new() + { + { "Secret-Custom-Header", "Value" }, + { "Content-Type", "text/json" } + }; + + await CreatePipelineAndSendRequest(response, loggingOptions, requestContentBytes: [1, 2, 3, 4, 5], requestHeaders: requestHeaders, requestUri: requestUri); + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.RequestEvent, "Request", EventLevel.Informational, SystemClientModelEventSourceName); + string headers = log.GetProperty("headers"); + StringAssert.Contains($"Date:08/16/2024{Environment.NewLine}", headers); + StringAssert.Contains($"Custom-Header:Value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Custom-Header:Value{Environment.NewLine}", headers); + + log = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseEvent, "Response", EventLevel.Informational, SystemClientModelEventSourceName); + headers = log.GetProperty("headers"); + StringAssert.Contains($"Custom-Response-Header:Improved value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Response-Header:Very secret{Environment.NewLine}", headers); + } + + [Test] + public async Task HeadersAndQueryParametersAreNotSanitizedWhenStarsILogger() + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + var mockHeaders = new MockResponseHeaders(new Dictionary { { "Custom-Response-Header", "Improved value" }, { "Secret-Response-Header", "Very secret" } }); + var response = new MockPipelineResponse(200, mockHeaders: mockHeaders); + response.SetContent([6, 7, 8, 9, 0]); + + loggingOptions.AllowedQueryParameters.Add("*"); + loggingOptions.AllowedHeaderNames.Add("*"); + + Uri requestUri = new("https://contoso.a.io?api-version=5&secret=123"); + + Dictionary requestHeaders = new() + { + { "Secret-Custom-Header", "Value" }, + { "Content-Type", "text/json" } + }; + + await CreatePipelineAndSendRequest(response, requestContentBytes: [1, 2, 3, 4, 5], requestHeaders: requestHeaders, requestUri: requestUri, loggingOptions: loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.RequestEvent, "Request", LogLevel.Information); + string headers = (log.GetValueFromArguments("headers")).ToString(); + StringAssert.Contains($"Date:08/16/2024{Environment.NewLine}", headers); + StringAssert.Contains($"Custom-Header:Value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Custom-Header:Value{Environment.NewLine}", headers); + + log = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseEvent, "Response", LogLevel.Information); + headers = (log.GetValueFromArguments("headers")).ToString(); + StringAssert.Contains($"Custom-Response-Header:Improved value{Environment.NewLine}", headers); + StringAssert.Contains($"Secret-Response-Header:Very secret{Environment.NewLine}", headers); + } + + [Test] + public async Task SendingARequestProducesRequestAndResponseLogMessagesEventSource() // RequestEvent, ResponseEvent + { + using TestClientEventListener listener = new(); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, requestContentBytes: requestContent); + + // Assert that the request log message is written and formatted correctly + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.RequestEvent, "Request", EventLevel.Informational, SystemClientModelEventSourceName); + Assert.AreEqual("http://example.com/", log.GetProperty("uri")); + Assert.AreEqual("GET", log.GetProperty("method")); + StringAssert.Contains($"Date:08/16/2024{Environment.NewLine}", log.GetProperty("headers")); + StringAssert.Contains($"Custom-Header:custom-header-value{Environment.NewLine}", log.GetProperty("headers")); + + // Assert that the response log message is written and formatted correctly + + log = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseEvent, "Response", EventLevel.Informational, SystemClientModelEventSourceName); + Assert.AreEqual(log.GetProperty("status"), 200); + StringAssert.Contains($"Custom-Response-Header:custom-response-header-value{Environment.NewLine}", log.GetProperty("headers")); + + // Assert that no other log messages were written + Assert.AreEqual(2, listener.EventData.Count()); + } + + [Test] + public async Task SendingARequestProducesRequestAndResponseLogMessagesILogger() // RequestEvent, ResponseEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + ClientLoggingOptions loggingOptions = new() + { + LoggerFactory = factory + }; + + await CreatePipelineAndSendRequest(response, requestContentBytes: requestContent, loggingOptions: loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + // Assert that the request log message is written and formatted correctly + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.RequestEvent, "Request", LogLevel.Information); + Assert.AreEqual("http://example.com/", log.GetValueFromArguments("uri")); + Assert.AreEqual("GET", log.GetValueFromArguments("method")); + StringAssert.Contains($"Date:08/16/2024{Environment.NewLine}", (log.GetValueFromArguments("headers")).ToString()); + StringAssert.Contains($"Custom-Header:custom-header-value{Environment.NewLine}", (log.GetValueFromArguments("headers")).ToString()); + + // Assert that the response log message is written and formatted correctly + + log = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseEvent, "Response", LogLevel.Information); + Assert.AreEqual(log.GetValueFromArguments("status"), 200); + StringAssert.Contains($"Custom-Response-Header:custom-response-header-value{Environment.NewLine}", (log.GetValueFromArguments("headers")).ToString()); + + // Assert that no other log messages were written + Assert.AreEqual(logger.Logs.Count(), 2); + } + + [Test] + public async Task ReceivingAnErrorResponseProducesAnErrorResponseLogMessageEventSource() // ErrorResponseEvent, ErrorResponseContentEvent + { + using TestClientEventListener listener = new(); + + byte[] responseContent = [6, 7, 8, 9, 0]; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = int.MaxValue + }; + + MockPipelineResponse response = new(400, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, loggingOptions, requestContentBytes: [1, 2, 3, 4, 5]); + + // Assert that the error response log message is written and formatted correctly + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseEvent, "ErrorResponse", EventLevel.Warning, SystemClientModelEventSourceName); + Assert.AreEqual(log.GetProperty("status"), 400); + StringAssert.Contains($"Custom-Response-Header:custom-response-header-value{Environment.NewLine}", log.GetProperty("headers")); + + // Assert that the error response content log message is written and formatted correctly + + log = listener.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseContentEvent, "ErrorResponseContent", EventLevel.Informational, SystemClientModelEventSourceName); + CollectionAssert.AreEqual(responseContent, log.GetProperty("content")); + } + + [Test] + public async Task ReceivingAnErrorResponseProducesAnErrorResponseLogMessageILogger() // ErrorResponseEvent, ErrorResponseContentEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + byte[] responseContent = [6, 7, 8, 9, 0]; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = int.MaxValue, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(400, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, requestContentBytes: [1, 2, 3, 4, 5], loggingOptions: loggingOptions); + + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + // Assert that the error response log message is written and formatted correctly + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseEvent, "ErrorResponse", LogLevel.Warning); + Assert.AreEqual(log.GetValueFromArguments("status"), 400); + StringAssert.Contains($"Custom-Response-Header:custom-response-header-value{Environment.NewLine}", (log.GetValueFromArguments("headers")).ToString()); + + // Assert that the error response content log message is written and formatted correctly + + log = logger.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseContentEvent, "ErrorResponseContent", LogLevel.Information); + CollectionAssert.AreEqual(responseContent, log.GetValueFromArguments("content")); + } + + [Test] + public async Task ContentLoggingEnabledProducesRequestContentAndResponseContentLogMessageEventSource() // RequestContentEvent, ResponseContentEvent + { + using TestClientEventListener listener = new(); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = int.MaxValue + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, loggingOptions, requestContentBytes: requestContent); + + // Assert that the request content log message is written and formatted correctly + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.RequestContentEvent, "RequestContent", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual(requestContent, log.GetProperty("content")); + + // Assert that the response content log message is written and formatted correctly + + log = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseContentEvent, "ResponseContent", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual(responseContent, log.GetProperty("content")); + + // Assert content was not written as text + + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.RequestContentTextEvent)); + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.ResponseContentTextEvent)); + } + + [Test] + public async Task ContentLoggingEnabledProducesRequestContentAndResponseContentLogMessageILogger() // RequestContentEvent, ResponseContentEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = int.MaxValue, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, requestContentBytes: requestContent, loggingOptions: loggingOptions); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + // Assert that the request content log message is written and formatted correctly + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.RequestContentEvent, "RequestContent", LogLevel.Debug); + Assert.AreEqual(requestContent, log.GetValueFromArguments("content")); + + // Assert that the response content log message is written and formatted correctly + + log = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseContentEvent, "ResponseContent", LogLevel.Debug); + Assert.AreEqual(responseContent, log.GetValueFromArguments("content")); + + // Assert content was not written as text + + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.RequestContentTextEvent)); + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.ResponseContentTextEvent)); + } + + [Test] + public async Task ContentLoggingEnabledProducesRequestContentAsTextAndResponseContentAsTextEventSource() // RequestContentTextEvent, ResponseContentTextEvent + { + using TestClientEventListener listener = new(); + + string requestContent = "Hello"; + string responseContent = "World!"; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = int.MaxValue + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultTextHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, loggingOptions, requestContentString: requestContent); + + // Assert that the request content text event is written and formatted correctly + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.RequestContentTextEvent, "RequestContentText", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual(requestContent, log.GetProperty("content")); + + // Assert that the response content text event is written and formatted correctly + + log = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseContentTextEvent, "ResponseContentText", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual(responseContent, log.GetProperty("content")); + + // Assert content was not written not as text + + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.RequestContentEvent)); + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.ResponseContentEvent)); + } + + [Test] + public async Task ContentLoggingEnabledProducesRequestContentAsTextAndResponseContentAsTextILogger() // RequestContentTextEvent, ResponseContentTextEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + string requestContent = "Hello"; + string responseContent = "World!"; + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + MessageContentSizeLimit = int.MaxValue, + LoggerFactory = factory + }; + + MockPipelineResponse response = new(200, mockHeaders: _defaultTextHeaders); + response.SetContent(responseContent); + + await CreatePipelineAndSendRequest(response, requestContentString: requestContent, loggingOptions: loggingOptions); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + // Assert that the request content text event is written and formatted correctly + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.RequestContentTextEvent, "RequestContentText", LogLevel.Debug); + Assert.AreEqual(requestContent, log.GetValueFromArguments("content")); + + // Assert that the response content text event is written and formatted correctly + + log = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseContentTextEvent, "ResponseContentText", LogLevel.Debug); + Assert.AreEqual(responseContent, log.GetValueFromArguments("content")); + + // Assert content was not written not as text + + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.RequestContentEvent)); + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.ResponseContentEvent)); + } + + [Test] + public async Task ContentLoggingEnabledProducesResponseContentAsTextWithSeekableTextStreamEventSource() // ResponseContentTextEvent + { + using TestClientEventListener listener = new(); + + await CreatePipelineAndSendRequestWithStreamingResponse(200, true, _defaultTextHeaders, new ClientLoggingOptions(), int.MaxValue); + + EventWrittenEventArgs logEvent = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseContentTextEvent, "ResponseContentText", EventLevel.Verbose, SystemClientModelEventSourceName); + Assert.AreEqual("Hello world", logEvent.GetProperty("content")); + } + + [Test] + public async Task ContentLoggingEnabledProducesResponseContentAsTextWithSeekableTextStreamILogger() // ResponseContentTextEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + await CreatePipelineAndSendRequestWithStreamingResponse(200, true, _defaultTextHeaders, loggingOptions, int.MaxValue); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent logEvent = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseContentTextEvent, "ResponseContentText", LogLevel.Debug); + Assert.AreEqual("Hello world", logEvent.GetValueFromArguments("content")); + } + + [Test] + public async Task ContentLoggingEnabledProducesErrorResponseContentAsTextWithSeekableTextStreamEventSource() // ErrorResponseContentTextEvent + { + using TestClientEventListener listener = new(); + + await CreatePipelineAndSendRequestWithStreamingResponse(500, true, _defaultTextHeaders, new ClientLoggingOptions(), 5); + + EventWrittenEventArgs logEvent = listener.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseContentTextEvent, "ErrorResponseContentText", EventLevel.Informational, SystemClientModelEventSourceName); + Assert.AreEqual("Hello", logEvent.GetProperty("content")); + } + + [Test] + public async Task ContentLoggingEnabledProducesErrorResponseContentAsTextWithSeekableTextStreamILogger() // ErrorResponseContentTextEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + await CreatePipelineAndSendRequestWithStreamingResponse(500, true, _defaultTextHeaders, loggingOptions, 5); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent logEvent = logger.GetAndValidateSingleEvent(LoggingEventIds.ErrorResponseContentTextEvent, "ErrorResponseContentText", LogLevel.Information); + Assert.AreEqual("Hello", logEvent.GetValueFromArguments("content")); + } + + [Test] + public async Task NonSeekableResponsesAreLoggedInBlocksEventSource() // ResponseContentBlockEvent + { + using TestClientEventListener listener = new(); + + await CreatePipelineAndSendRequestWithStreamingResponse(200, false, _defaultHeaders, new ClientLoggingOptions()); + + EventWrittenEventArgs[] contentEvents = listener.EventsById(LoggingEventIds.ResponseContentBlockEvent).ToArray(); + + Assert.AreEqual(2, contentEvents.Length); + + Assert.AreEqual(EventLevel.Verbose, contentEvents[0].Level); + Assert.AreEqual("ResponseContentBlock", contentEvents[0].EventName); + Assert.AreEqual(0, contentEvents[0].GetProperty("blockNumber")); + Assert.AreEqual(SystemClientModelEventSourceName, contentEvents[0].EventSource.Name); + CollectionAssert.AreEqual("Hello "u8.ToArray(), contentEvents[0].GetProperty("content")); + + Assert.AreEqual(EventLevel.Verbose, contentEvents[1].Level); + Assert.AreEqual("ResponseContentBlock", contentEvents[1].EventName); + Assert.AreEqual(1, contentEvents[1].GetProperty("blockNumber")); + Assert.AreEqual(SystemClientModelEventSourceName, contentEvents[1].EventSource.Name); + CollectionAssert.AreEqual("world"u8.ToArray(), contentEvents[1].GetProperty("content")); + + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.ResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponsesAreLoggedInBlocksILogger() // ResponseContentBlockEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + await CreatePipelineAndSendRequestWithStreamingResponse(200, false, _defaultHeaders, loggingOptions); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent[] contentEvents = logger.EventsById(LoggingEventIds.ResponseContentBlockEvent).ToArray(); + + Assert.AreEqual(2, contentEvents.Length); + + Assert.AreEqual(LogLevel.Debug, contentEvents[0].LogLevel); + Assert.AreEqual("ResponseContentBlock", contentEvents[0].EventId.Name); + Assert.AreEqual(0, contentEvents[0].GetValueFromArguments("blockNumber")); + CollectionAssert.AreEqual("Hello "u8.ToArray(), contentEvents[0].GetValueFromArguments("content")); + + Assert.AreEqual(LogLevel.Debug, contentEvents[1].LogLevel); + Assert.AreEqual("ResponseContentBlock", contentEvents[1].EventId.Name); + Assert.AreEqual(1, contentEvents[1].GetValueFromArguments("blockNumber")); + CollectionAssert.AreEqual("world"u8.ToArray(), contentEvents[1].GetValueFromArguments("content")); + + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.ResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponseErrorsAreLoggedInBlocksEventSource() // ErrorResponseContentBlockEvent + { + using TestClientEventListener listener = new(); + + await CreatePipelineAndSendRequestWithStreamingResponse(500, false, _defaultHeaders, new ClientLoggingOptions()); + + EventWrittenEventArgs[] errorContentEvents = listener.EventsById(LoggingEventIds.ErrorResponseContentBlockEvent).ToArray(); + + Assert.AreEqual(2, errorContentEvents.Length); + + Assert.AreEqual(EventLevel.Informational, errorContentEvents[0].Level); + Assert.AreEqual("ErrorResponseContentBlock", errorContentEvents[0].EventName); + Assert.AreEqual(0, errorContentEvents[0].GetProperty("blockNumber")); + Assert.AreEqual(SystemClientModelEventSourceName, errorContentEvents[0].EventSource.Name); + CollectionAssert.AreEqual("Hello "u8.ToArray(), errorContentEvents[0].GetProperty("content")); + + Assert.AreEqual(EventLevel.Informational, errorContentEvents[1].Level); + Assert.AreEqual("ErrorResponseContentBlock", errorContentEvents[1].EventName); + Assert.AreEqual(1, errorContentEvents[1].GetProperty("blockNumber")); + Assert.AreEqual(SystemClientModelEventSourceName, errorContentEvents[1].EventSource.Name); + CollectionAssert.AreEqual("world"u8.ToArray(), errorContentEvents[1].GetProperty("content")); + + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.ErrorResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponsesErrorsAreLoggedInBlocksILogger() // ErrorResponseContentBlockEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + await CreatePipelineAndSendRequestWithStreamingResponse(500, false, _defaultHeaders, loggingOptions); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent[] errorContentEvents = logger.EventsById(LoggingEventIds.ErrorResponseContentBlockEvent).ToArray(); + + Assert.AreEqual(2, errorContentEvents.Length); + + Assert.AreEqual(LogLevel.Information, errorContentEvents[0].LogLevel); + Assert.AreEqual("ErrorResponseContentBlock", errorContentEvents[0].EventId.Name); + Assert.AreEqual(0, errorContentEvents[0].GetValueFromArguments("blockNumber")); + CollectionAssert.AreEqual("Hello "u8.ToArray(), errorContentEvents[0].GetValueFromArguments("content")); + + Assert.AreEqual(LogLevel.Information, errorContentEvents[1].LogLevel); + Assert.AreEqual("ErrorResponseContentBlock", errorContentEvents[1].EventId.Name); + Assert.AreEqual(1, errorContentEvents[1].GetValueFromArguments("blockNumber")); + CollectionAssert.AreEqual("world"u8.ToArray(), errorContentEvents[1].GetValueFromArguments("content")); + + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.ErrorResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponsesAreLoggedInTextBlocksEventSource() // ResponseContentTextBlockEvent + { + using TestClientEventListener listener = new(); + + await CreatePipelineAndSendRequestWithStreamingResponse(200, false, _defaultTextHeaders, new ClientLoggingOptions()); + + EventWrittenEventArgs[] contentEvents = listener.EventsById(LoggingEventIds.ResponseContentTextBlockEvent).ToArray(); + + Assert.AreEqual(2, contentEvents.Length); + + Assert.AreEqual(EventLevel.Verbose, contentEvents[0].Level); + + Assert.AreEqual("ResponseContentTextBlock", contentEvents[0].EventName); + Assert.AreEqual(0, contentEvents[0].GetProperty("blockNumber")); + Assert.AreEqual("Hello ", contentEvents[0].GetProperty("content")); + Assert.AreEqual(SystemClientModelEventSourceName, contentEvents[0].EventSource.Name); + + Assert.AreEqual(EventLevel.Verbose, contentEvents[1].Level); + Assert.AreEqual("ResponseContentTextBlock", contentEvents[1].EventName); + Assert.AreEqual(1, contentEvents[1].GetProperty("blockNumber")); + Assert.AreEqual("world", contentEvents[1].GetProperty("content")); + Assert.AreEqual(SystemClientModelEventSourceName, contentEvents[1].EventSource.Name); + + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.ResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponsesAreLoggedInTextBlocksILogger() // ResponseContentTextBlockEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + await CreatePipelineAndSendRequestWithStreamingResponse(200, false, _defaultTextHeaders, loggingOptions); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent[] contentEvents = logger.EventsById(LoggingEventIds.ResponseContentTextBlockEvent).ToArray(); + + Assert.AreEqual(2, contentEvents.Length); + + Assert.AreEqual(LogLevel.Debug, contentEvents[0].LogLevel); + + Assert.AreEqual("ResponseContentTextBlock", contentEvents[0].EventId.Name); + Assert.AreEqual(0, contentEvents[0].GetValueFromArguments("blockNumber")); + Assert.AreEqual("Hello ", contentEvents[0].GetValueFromArguments("content")); + + Assert.AreEqual(LogLevel.Debug, contentEvents[1].LogLevel); + Assert.AreEqual("ResponseContentTextBlock", contentEvents[1].EventId.Name); + Assert.AreEqual(1, contentEvents[1].GetValueFromArguments("blockNumber")); + Assert.AreEqual("world", contentEvents[1].GetValueFromArguments("content")); + + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.ResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponsesErrorsAreLoggedInTextBlocksEventSource() // ErrorResponseContentTextBlockEvent + { + using TestClientEventListener listener = new(); + + await CreatePipelineAndSendRequestWithStreamingResponse(500, false, _defaultTextHeaders, new ClientLoggingOptions()); + + EventWrittenEventArgs[] errorContentEvents = listener.EventsById(LoggingEventIds.ErrorResponseContentTextBlockEvent).ToArray(); + + Assert.AreEqual(2, errorContentEvents.Length); + + Assert.AreEqual(EventLevel.Informational, errorContentEvents[0].Level); + Assert.AreEqual("ErrorResponseContentTextBlock", errorContentEvents[0].EventName); + Assert.AreEqual(0, errorContentEvents[0].GetProperty("blockNumber")); + Assert.AreEqual("Hello ", errorContentEvents[0].GetProperty("content")); + Assert.AreEqual(SystemClientModelEventSourceName, errorContentEvents[0].EventSource.Name); + + Assert.AreEqual(EventLevel.Informational, errorContentEvents[1].Level); + Assert.AreEqual("ErrorResponseContentTextBlock", errorContentEvents[1].EventName); + Assert.AreEqual(1, errorContentEvents[1].GetProperty("blockNumber")); + Assert.AreEqual("world", errorContentEvents[1].GetProperty("content")); + Assert.AreEqual(SystemClientModelEventSourceName, errorContentEvents[1].EventSource.Name); + + CollectionAssert.IsEmpty(listener.EventsById(LoggingEventIds.ErrorResponseContentEvent)); + } + + [Test] + public async Task NonSeekableResponsesErrorsAreLoggedInTextBlocksILogger() // ErrorResponseContentTextBlockEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + ClientLoggingOptions loggingOptions = new() { LoggerFactory = factory }; + + await CreatePipelineAndSendRequestWithStreamingResponse(500, false, _defaultTextHeaders, loggingOptions); + TestLogger logger = factory.GetLogger(LoggingPolicyCategoryName); + + LoggerEvent[] errorContentEvents = logger.EventsById(LoggingEventIds.ErrorResponseContentTextBlockEvent).ToArray(); + + Assert.AreEqual(2, errorContentEvents.Length); + + Assert.AreEqual(LogLevel.Information, errorContentEvents[0].LogLevel); + Assert.AreEqual("ErrorResponseContentTextBlock", errorContentEvents[0].EventId.Name); + Assert.AreEqual(0, errorContentEvents[0].GetValueFromArguments("blockNumber")); + Assert.AreEqual("Hello ", errorContentEvents[0].GetValueFromArguments("content")); + + Assert.AreEqual(LogLevel.Information, errorContentEvents[1].LogLevel); + Assert.AreEqual("ErrorResponseContentTextBlock", errorContentEvents[1].EventId.Name); + Assert.AreEqual(1, errorContentEvents[1].GetValueFromArguments("blockNumber")); + Assert.AreEqual("world", errorContentEvents[1].GetValueFromArguments("content")); + + CollectionAssert.IsEmpty(logger.EventsById(LoggingEventIds.ErrorResponseContentEvent)); + } + + #endregion + + #region Helpers + + private class TestEventListenerWarning : TestClientEventListener + { + protected override void OnEventSourceCreated(EventSource eventSource) + { + if (eventSource.Name == "System.ClientModel") + { + Console.WriteLine("Warning"); + EnableEvents(eventSource, EventLevel.Warning); + } + } + } + + private async Task CreatePipelineAndSendRequestWithStreamingResponse(int statusCode, + bool isSeekable, + MockResponseHeaders responseHeaders, + ClientLoggingOptions loggingOptions, + int maxLength = int.MaxValue) + { + MockPipelineResponse response = new(status: statusCode, mockHeaders: responseHeaders); + + byte[] responseContent = Encoding.UTF8.GetBytes("Hello world"); + if (isSeekable) + { + response.ContentStream = new MemoryStream(responseContent); + } + else + { + response.ContentStream = new NonSeekableMemoryStream(responseContent); + } + + loggingOptions.EnableMessageContentLogging = true; + loggingOptions.MessageContentSizeLimit = maxLength; + + // These tests are essentially testing whether the logging policy works + // correctly when responses are buffered (memory stream) and unbuffered + // (non-seekable). In order to validate the intent of the test, we set + // message.BufferResponse accordingly here. + await CreatePipelineAndSendRequest(response, loggingOptions, bufferResponse: isSeekable); + + var buffer = new byte[11]; + + if (IsAsync) + { +#if NET462 + Assert.AreEqual(6, await response.ContentStream.ReadAsync(buffer, 5, 6)); + Assert.AreEqual(5, await response.ContentStream.ReadAsync(buffer, 6, 5)); + Assert.AreEqual(0, await response.ContentStream.ReadAsync(buffer, 0, 5)); +#else + Assert.AreEqual(6, await response.ContentStream.ReadAsync(buffer.AsMemory(5, 6))); + Assert.AreEqual(5, await response.ContentStream.ReadAsync(buffer.AsMemory(6, 5))); + Assert.AreEqual(0, await response.ContentStream.ReadAsync(buffer.AsMemory(0, 5))); +#endif + } + else + { + Assert.AreEqual(6, response.ContentStream.Read(buffer, 5, 6)); + Assert.AreEqual(5, response.ContentStream.Read(buffer, 6, 5)); + Assert.AreEqual(0, response.ContentStream.Read(buffer, 0, 5)); + } + } + + private async Task CreatePipelineAndSendRequest(MockPipelineResponse response, + ClientLoggingOptions? loggingOptions = default, + string? requestContentString = default, + byte[]? requestContentBytes = default, + Dictionary? requestHeaders = default, + Uri? requestUri = default, + bool? bufferResponse = default) + { + ClientLoggingOptions clientLoggingOptions = loggingOptions ?? new ClientLoggingOptions(); + clientLoggingOptions.AllowedHeaderNames.Add("Custom-Header"); + clientLoggingOptions.AllowedHeaderNames.Add("Custom-Response-Header"); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response), + ClientLoggingOptions = clientLoggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = requestUri ?? new Uri("http://example.com"); + + if (requestHeaders != null) + { + foreach (KeyValuePair header in requestHeaders) + { + message.Request.Headers.Add(header.Key, header.Value); + } + } + + message.Request.Headers.Add("Custom-Header", "custom-header-value"); + message.Request.Headers.Add("Date", "08/16/2024"); + + if (bufferResponse != null) + { + message.BufferResponse = bufferResponse.Value; + } + + if (requestContentBytes != null) + { + message.Request.Content = BinaryContent.Create(new BinaryData(requestContentBytes)); + } + else if (requestContentString != null) + { + message.Request.Headers.Add("Content-Type", "text/plain"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContentString)); + } + + await pipeline.SendSyncOrAsync(message, IsAsync); + } + + #endregion +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageSanitizerTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageSanitizerTests.cs new file mode 100644 index 000000000000..404f22c59f89 --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineMessageSanitizerTests.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.Collections.Generic; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal +{ + public class PipelineMessageSanitizerTests + { + [Test] + [TestCase("?a", "?a")] + [TestCase("?a=b", "?a=b")] + [TestCase("?a=b&", "?a=b&")] + [TestCase("?d=b&", "?d=*&")] + [TestCase("?d=a", "?d=*")] + [TestCase("?a=b&d", "?a=b&d")] + [TestCase("?a=b&d=1&", "?a=b&d=*&")] + [TestCase("?a=b&d=1&a1", "?a=b&d=*&a1")] + [TestCase("?a=b&d=1&a1=", "?a=b&d=*&a1=")] + [TestCase("?a=b&d=11&a1=&", "?a=b&d=*&a1=&")] + [TestCase("?d&d&d&", "?d&d&d&")] + [TestCase("?a&a&a&a", "?a&a&a&a")] + [TestCase("?&&&&&&&", "?&&&&&&&")] + [TestCase("?d", "?d")] + public void QueryIsSanitized(string input, string expected) + { + var sanitizer = new PipelineMessageSanitizer(["A", "a1", "a-2"], [], "*"); + + Assert.AreEqual("http://localhost/" + expected, sanitizer.SanitizeUrl("http://localhost/" + input)); + } + + [Test] + public void HeaderIsSanitized() + { + var sanitizer = new PipelineMessageSanitizer([], [ "header-1" ], "*"); + + Assert.AreEqual("value1", sanitizer.SanitizeHeader("header-1", "value1")); + Assert.AreEqual("*", sanitizer.SanitizeHeader("header-2", "value2")); + } + + [Test] + public void EverythingIsSanitizedWithNoAllowedHeadersOrQueries() + { + var sanitizer = new PipelineMessageSanitizer([], [], "*"); + + var uri = new Uri("http://localhost/?a=b"); + + Assert.AreEqual("http://localhost/?a=*", sanitizer.SanitizeUrl(uri.ToString())); + Assert.AreEqual("*", sanitizer.SanitizeHeader("header", "value")); + } + } +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineRetryLoggerTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineRetryLoggerTests.cs new file mode 100644 index 000000000000..f82dda9d73dc --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineRetryLoggerTests.cs @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using System.ClientModel.Tests.TestFramework; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests; +using ClientModel.Tests.Mocks; +using Microsoft.Extensions.Logging; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal; + +// Avoid running these tests in parallel with anything else that's sharing the event source +[NonParallelizable] +public class PipelineRetryLoggerTests : SyncAsyncPolicyTestBase +{ + private const string RetryPolicyCategoryName = "System.ClientModel.Primitives.ClientRetryPolicy"; + private const string SystemClientModelEventSourceName = "System-ClientModel"; + private readonly MockResponseHeaders _defaultHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + private readonly MockResponseHeaders _defaultTextHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Content-Type", "text/plain" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + + public PipelineRetryLoggerTests(bool isAsync) : base(isAsync) + { + } + + #region Unit tests + + [Test] + public void RetriesAreLoggedToILoggerAndNotEventSourceWhenILoggerIsProvided() + { + using TestClientEventListener listener = new(); + using TestLoggingFactory factory = new(LogLevel.Debug); + + PipelineRetryLogger retryLogger = new(factory); + retryLogger.LogRequestRetrying("requestId", 1, 1); + + TestLogger logger = factory.GetLogger(RetryPolicyCategoryName); + logger.SingleEventById(10); // RequestRetrying + + CollectionAssert.IsEmpty(listener.EventData); + } + + [Test] + public void RetriesAreLoggedToILoggerAndNotEventSourceWhenILoggerIsProvidedAndLogLevelIsWarning() + { + using TestClientEventListener listener = new(); // Verbose listener + using TestLoggingFactory factory = new(LogLevel.Warning); // Warnings only + + PipelineRetryLogger retryLogger = new(factory); + retryLogger.LogRequestRetrying("requestId", 1, 1); + + CollectionAssert.IsEmpty(listener.EventData); + CollectionAssert.IsEmpty(factory.GetLogger(RetryPolicyCategoryName).Logs); + } + + #endregion + + #region Integration tests + + [Test] + public async Task SendingRequestThatIsRetriedProducesRequestRetryingEventOnEachRetryEventSource() // RequestRetryingEvent + { + using TestClientEventListener listener = new(); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", [429, 429, 200]), + ClientLoggingOptions = new() + }; + ClientPipeline pipeline = ClientPipeline.Create(options); + + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContent)); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + EventWrittenEventArgs args = listener.SingleEventById(LoggingEventIds.RequestRetryingEvent, (i => i.GetProperty("retryNumber") == 1)); + Assert.AreEqual("RequestRetrying", args.EventName); + Assert.AreEqual(EventLevel.Informational, args.Level); + Assert.Less(args.GetProperty("seconds"), 1); + + args = listener.SingleEventById(LoggingEventIds.RequestRetryingEvent, (i => i.GetProperty("retryNumber") == 2)); + Assert.AreEqual("RequestRetrying", args.EventName); + Assert.AreEqual(EventLevel.Informational, args.Level); + Assert.Less(args.GetProperty("seconds"), 1); + + // 2 retry logs + 3 request logs + 3 response logs + Assert.AreEqual(8, listener.EventData.Count()); + } + + [Test] + public async Task SendingRequestThatIsRetriedProducesRequestRetryingEventOnEachRetryILogger() // RequestRetryingEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", [429, 429, 200]), + ClientLoggingOptions = new() + { + LoggerFactory = factory + } + }; + ClientPipeline pipeline = ClientPipeline.Create(options); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContent)); + + await pipeline.SendSyncOrAsync(message, IsAsync); + TestLogger logger = factory.GetLogger(RetryPolicyCategoryName); + + LoggerEvent log = logger.SingleEventById(LoggingEventIds.RequestRetryingEvent, (i => i.GetValueFromArguments("retryNumber") == 1)); + Assert.AreEqual("RequestRetrying", log.EventId.Name); + Assert.AreEqual(LogLevel.Information, log.LogLevel); + Assert.Less(log.GetValueFromArguments("seconds"), 1); + + log = logger.SingleEventById(LoggingEventIds.RequestRetryingEvent, (i => i.GetValueFromArguments("retryNumber") == 2)); + Assert.AreEqual("RequestRetrying", log.EventId.Name); + Assert.AreEqual(LogLevel.Information, log.LogLevel); + Assert.Less(log.GetValueFromArguments("seconds"), 1); + + // No other logs should have been written to the retry logger + Assert.AreEqual(2, logger.Logs.Count()); + } + + #endregion +} diff --git a/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineTransportLoggerTests.cs b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineTransportLoggerTests.cs new file mode 100644 index 000000000000..5273da3078fb --- /dev/null +++ b/sdk/core/System.ClientModel/tests/internal/Internal/Logging/PipelineTransportLoggerTests.cs @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.ClientModel.Internal; +using System.ClientModel.Primitives; +using System.ClientModel.Tests.TestFramework; +using System.Collections.Generic; +using System.Diagnostics.Tracing; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Azure.Core.TestFramework; +using ClientModel.Tests; +using ClientModel.Tests.Mocks; +using Microsoft.Extensions.Logging; +using NUnit.Framework; + +namespace System.ClientModel.Tests.Internal; + +// Avoid running these tests in parallel with anything else that's sharing the event source +[NonParallelizable] +public class PipelineTransportLoggerTests : SyncAsyncPolicyTestBase +{ + private const string PipelineTransportCategoryName = "System.ClientModel.Primitives.PipelineTransport"; + private const string SystemClientModelEventSourceName = "System.ClientModel"; + private readonly MockResponseHeaders _defaultHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + private readonly MockResponseHeaders _defaultTextHeaders = new(new Dictionary() + { + { "Custom-Response-Header", "custom-response-header-value" }, + { "Content-Type", "text/plain" }, + { "Date", "4/29/2024" }, + { "ETag", "version1" } + }); + + public PipelineTransportLoggerTests(bool isAsync) : base(isAsync) + { + } + + #region Unit tests + + [Test] + public void LogsAreLoggedToILoggerAndNotEventSourceWhenILoggerIsProvided() + { + using TestClientEventListener listener = new(); + using TestLoggingFactory factory = new(LogLevel.Debug); + + PipelineTransportLogger transportLogger = new(factory); + + transportLogger.LogExceptionResponse("requestId", new InvalidOperationException()); + transportLogger.LogResponseDelay("requestId", 1); + + TestLogger logger = factory.GetLogger(PipelineTransportCategoryName); + logger.SingleEventById(7); // ResponseDelay + logger.SingleEventById(18); // ExceptionResponse + + CollectionAssert.IsEmpty(listener.EventData); + } + + [Test] + public void LogsAreLoggedToILoggerAndNotEventSourceWhenILoggerIsProvidedAndLogLevelIsWarning() + { + using TestClientEventListener listener = new(); + using TestLoggingFactory factory = new(LogLevel.Warning); + + PipelineTransportLogger transportLogger = new(factory); + + transportLogger.LogExceptionResponse("requestId", new InvalidOperationException()); + transportLogger.LogResponseDelay("requestId", 1); + + TestLogger logger = factory.GetLogger(PipelineTransportCategoryName); + logger.SingleEventById(7); // ResponseDelay + + CollectionAssert.IsEmpty(listener.EventData); + } + + #endregion + + #region Integration tests + + [Test] + public void GettingExceptionResponseProducesEventsEventSource() // ExceptionResponseEvent + { + using TestClientEventListener listener = new(); + + var exception = new InvalidOperationException(); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", (PipelineMessage i) => throw exception, true, null, false), + ClientLoggingOptions = new() + { + EnableMessageContentLogging = true + } + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Headers.Add("User-Agent", "agent"); + + Assert.ThrowsAsync(async () => await pipeline.SendSyncOrAsync(message, IsAsync)); + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.ExceptionResponseEvent, "ExceptionResponse", EventLevel.Informational, SystemClientModelEventSourceName); + Assert.AreEqual(exception.ToString().Split(Environment.NewLine.ToCharArray())[0], log.GetProperty("exception").Split(Environment.NewLine.ToCharArray())[0]); + } + + [Test] + public void GettingExceptionResponseProducesEventsILogger() // ExceptionResponseEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + var exception = new InvalidOperationException(); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + LoggerFactory = factory + }; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", (PipelineMessage i) => throw exception, true, factory), + ClientLoggingOptions = loggingOptions + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Headers.Add("User-Agent", "agent"); + + Assert.ThrowsAsync(async () => await pipeline.SendSyncOrAsync(message, IsAsync)); + TestLogger logger = factory.GetLogger(PipelineTransportCategoryName); + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.ExceptionResponseEvent, "ExceptionResponse", LogLevel.Information); + Assert.AreEqual(exception, log.Exception); + } + + [Test] + public async Task ResponseReceivedAfterThreeSecondsProducesResponseDelayEventEventSource() // ResponseDelayEvent + { + using TestClientEventListener listener = new(); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response, true, null, true), + ClientLoggingOptions = new(), + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + + PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContent)); + + await pipeline.SendSyncOrAsync(message, IsAsync); + + // Assert that the response delay log message is written and formatted correctly + + EventWrittenEventArgs log = listener.GetAndValidateSingleEvent(LoggingEventIds.ResponseDelayEvent, "ResponseDelay", EventLevel.Warning, SystemClientModelEventSourceName); + Assert.Greater(log.GetProperty("seconds"), 3); + } + + [Test] + public async Task ResponseReceivedAfterThreeSecondsProducesResponseDelayEventILogger() // ResponseDelayEvent + { + using TestLoggingFactory factory = new(LogLevel.Debug); + + byte[] requestContent = [1, 2, 3, 4, 5]; + byte[] responseContent = [6, 7, 8, 9, 0]; + + MockPipelineResponse response = new(200, mockHeaders: _defaultHeaders); + response.SetContent(responseContent); + + ClientLoggingOptions loggingOptions = new() + { + EnableMessageContentLogging = true, + LoggerFactory = factory + }; + + ClientPipelineOptions options = new() + { + Transport = new MockPipelineTransport("Transport", i => response, true, factory, true), + ClientLoggingOptions = loggingOptions, + RetryPolicy = new ObservablePolicy("RetryPolicy") + }; + + ClientPipeline pipeline = ClientPipeline.Create(options); + + using PipelineMessage message = pipeline.CreateMessage(); + message.Request.Method = "GET"; + message.Request.Uri = new Uri("http://example.com"); + message.Request.Content = BinaryContent.Create(new BinaryData(requestContent)); + + await pipeline.SendSyncOrAsync(message, IsAsync); + TestLogger logger = factory.GetLogger(PipelineTransportCategoryName); + + // Assert that the response log message is written and formatted correctly + + LoggerEvent log = logger.GetAndValidateSingleEvent(LoggingEventIds.ResponseDelayEvent, "ResponseDelay", LogLevel.Warning); + Assert.Greater(log.GetValueFromArguments("seconds"), 3); + } + + #endregion +}