diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs index c027b934297..1eeb94058ee 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionPipeline.cs @@ -154,7 +154,7 @@ private async IAsyncEnumerable ProcessAsync(IEnumerable ProcessAsync(IEnumerable IngestAsync(IngestionDocument document, Activity? parentActivity, CancellationToken cancellationToken) { foreach (IngestionDocumentProcessor processor in DocumentProcessors) { @@ -188,5 +189,7 @@ private async Task IngestAsync(IngestionDocument document, Activity? parentActiv _logger?.WritingChunks(GetShortName(_writer)); await _writer.WriteAsync(chunks, cancellationToken).ConfigureAwait(false); _logger?.WroteChunks(document.Identifier); + + return document; } } diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionResult.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionResult.cs index 3e325116be3..1a4e57ea3b8 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionResult.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/IngestionResult.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.IO; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.DataIngestion; @@ -13,9 +12,9 @@ namespace Microsoft.Extensions.DataIngestion; public sealed class IngestionResult { /// - /// Gets the source file that was ingested. + /// Gets the ID of the document that was ingested. /// - public FileInfo Source { get; } + public string DocumentId { get; } /// /// Gets the ingestion document created from the source file, if reading the document has succeeded. @@ -32,9 +31,9 @@ public sealed class IngestionResult /// public bool Succeeded => Exception is null; - internal IngestionResult(FileInfo source, IngestionDocument? document, Exception? exception) + internal IngestionResult(string documentId, IngestionDocument? document, Exception? exception) { - Source = Throw.IfNull(source); + DocumentId = Throw.IfNullOrEmpty(documentId); Document = document; Exception = exception; } diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Log.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Log.cs index 0d1ce4daa13..58732e8ead7 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Log.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Log.cs @@ -31,5 +31,11 @@ internal static partial class Log [LoggerMessage(6, LogLevel.Error, "An error occurred while ingesting document '{identifier}'.")] internal static partial void IngestingFailed(this ILogger logger, Exception exception, string identifier); + + [LoggerMessage(7, LogLevel.Error, "The AI chat service returned {resultCount} instead of {expectedCount} results.")] + internal static partial void UnexpectedResultsCount(this ILogger logger, int resultCount, int expectedCount); + + [LoggerMessage(8, LogLevel.Error, "Unexpected enricher failure.")] + internal static partial void UnexpectedEnricherFailure(this ILogger logger, Exception exception); } } diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj b/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj index b81b2724f51..acc451c9958 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj @@ -15,6 +15,7 @@ + @@ -25,7 +26,6 @@ - diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs index e1cb1ca7438..ad7b7d645d6 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ClassificationEnricher.cs @@ -2,13 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Frozen; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Text; using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.DataIngestion; @@ -21,30 +19,28 @@ namespace Microsoft.Extensions.DataIngestion; /// an optional fallback class for cases where no suitable classification can be determined. public sealed class ClassificationEnricher : IngestionChunkProcessor { - private readonly IChatClient _chatClient; - private readonly ChatOptions? _chatOptions; - private readonly FrozenSet _predefinedClasses; + private readonly EnricherOptions _options; private readonly ChatMessage _systemPrompt; + private readonly ILogger? _logger; /// /// Initializes a new instance of the class. /// - /// The chat client used for classification. + /// The options for the classification enricher. /// The set of predefined classification classes. - /// Options for the chat client. /// The fallback class to use when no suitable classification is found. When not provided, it defaults to "Unknown". - public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan predefinedClasses, - ChatOptions? chatOptions = null, string? fallbackClass = null) + public ClassificationEnricher(EnricherOptions options, ReadOnlySpan predefinedClasses, + string? fallbackClass = null) { - _chatClient = Throw.IfNull(chatClient); - _chatOptions = chatOptions; + _options = Throw.IfNull(options).Clone(); if (string.IsNullOrWhiteSpace(fallbackClass)) { fallbackClass = "Unknown"; } - _predefinedClasses = CreatePredefinedSet(predefinedClasses, fallbackClass!); + Validate(predefinedClasses, fallbackClass!); _systemPrompt = CreateSystemPrompt(predefinedClasses, fallbackClass!); + _logger = _options.LoggerFactory?.CreateLogger(); } /// @@ -53,28 +49,10 @@ public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan prede public static string MetadataKey => "classification"; /// - public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(chunks); - - await foreach (IngestionChunk chunk in chunks.WithCancellation(cancellationToken)) - { - var response = await _chatClient.GetResponseAsync( - [ - _systemPrompt, - new(ChatRole.User, chunk.Content) - ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - - chunk.Metadata[MetadataKey] = _predefinedClasses.Contains(response.Text) - ? response.Text - : throw new InvalidOperationException($"Classification returned an unexpected class: '{response.Text}'."); - - yield return chunk; - } - } + public override IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, CancellationToken cancellationToken = default) + => Batching.ProcessAsync(chunks, _options, MetadataKey, _systemPrompt, _logger, cancellationToken); - private static FrozenSet CreatePredefinedSet(ReadOnlySpan predefinedClasses, string fallbackClass) + private static void Validate(ReadOnlySpan predefinedClasses, string fallbackClass) { if (predefinedClasses.Length == 0) { @@ -84,15 +62,6 @@ private static FrozenSet CreatePredefinedSet(ReadOnlySpan predef HashSet predefinedClassesSet = new(StringComparer.Ordinal) { fallbackClass }; foreach (string predefinedClass in predefinedClasses) { -#if NET - if (predefinedClass.Contains(',', StringComparison.Ordinal)) -#else - if (predefinedClass.IndexOf(',') >= 0) -#endif - { - Throw.ArgumentException(nameof(predefinedClasses), $"Predefined class '{predefinedClass}' must not contain ',' character."); - } - if (!predefinedClassesSet.Add(predefinedClass)) { if (predefinedClass.Equals(fallbackClass, StringComparison.Ordinal)) @@ -103,13 +72,11 @@ private static FrozenSet CreatePredefinedSet(ReadOnlySpan predef Throw.ArgumentException(nameof(predefinedClasses), $"Duplicate class found: '{predefinedClass}'."); } } - - return predefinedClassesSet.ToFrozenSet(); } private static ChatMessage CreateSystemPrompt(ReadOnlySpan predefinedClasses, string fallbackClass) { - StringBuilder sb = new("You are a classification expert. Analyze the given text and assign a single, most relevant class. Use only the following predefined classes: "); + StringBuilder sb = new("You are a classification expert. For each of the following texts, assign a single, most relevant class. Use only the following predefined classes: "); #if NET9_0_OR_GREATER sb.AppendJoin(", ", predefinedClasses!); diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/EnricherOptions.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/EnricherOptions.cs new file mode 100644 index 00000000000..182e07d9c1f --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/EnricherOptions.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +/// +/// Represents options for enrichers that use an AI chat client. +/// +public class EnricherOptions +{ + /// + /// Initializes a new instance of the class. + /// + /// The AI chat client to be used. + public EnricherOptions(IChatClient chatClient) + { + ChatClient = Throw.IfNull(chatClient); + } + + /// + /// Gets the AI chat client to be used. + /// + public IChatClient ChatClient { get; } + + /// + /// Gets or sets the options for the . + /// + public ChatOptions? ChatOptions { get; set; } + + /// + /// Gets or sets the logger factory to be used for logging. + /// + /// + /// Enricher failures should not fail the whole ingestion pipeline, as they are best-effort enhancements. + /// This logger factory can be used to create loggers to log such failures. + /// + public ILoggerFactory? LoggerFactory { get; set; } + + /// + /// Gets or sets the batch size for processing chunks. Default is 20. + /// + public int BatchSize { get; set => field = Throw.IfLessThanOrEqual(value, 0); } = 20; + + internal EnricherOptions Clone() => new(ChatClient) + { + ChatOptions = ChatOptions?.Clone(), + LoggerFactory = LoggerFactory, + BatchSize = BatchSize + }; +} diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs index 5f68552cc3f..b133e0fa31a 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/ImageAlternativeTextEnricher.cs @@ -2,9 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.DataIngestion; @@ -15,20 +17,19 @@ namespace Microsoft.Extensions.DataIngestion; /// public sealed class ImageAlternativeTextEnricher : IngestionDocumentProcessor { - private readonly IChatClient _chatClient; - private readonly ChatOptions? _chatOptions; + private readonly EnricherOptions _options; private readonly ChatMessage _systemPrompt; + private readonly ILogger? _logger; /// /// Initializes a new instance of the class. /// - /// The chat client used to get responses for generating alternative text. - /// Options for the chat client. - public ImageAlternativeTextEnricher(IChatClient chatClient, ChatOptions? chatOptions = null) + /// The options for generating alternative text. + public ImageAlternativeTextEnricher(EnricherOptions options) { - _chatClient = Throw.IfNull(chatClient); - _chatOptions = chatOptions; - _systemPrompt = new(ChatRole.System, "Write a detailed alternative text for this image with less than 50 words."); + _options = Throw.IfNull(options).Clone(); + _systemPrompt = new(ChatRole.System, "For each of the following images, write a detailed alternative text with fewer than 50 words."); + _logger = _options.LoggerFactory?.CreateLogger(); } /// @@ -36,39 +37,86 @@ public override async Task ProcessAsync(IngestionDocument doc { _ = Throw.IfNull(document); + List? batch = null; + foreach (var element in document.EnumerateContent()) { if (element is IngestionDocumentImage image) { - await ProcessAsync(image, cancellationToken).ConfigureAwait(false); + if (ShouldProcess(image)) + { + batch ??= new(_options.BatchSize); + batch.Add(image); + + if (batch.Count == _options.BatchSize) + { + await ProcessAsync(batch, cancellationToken).ConfigureAwait(false); + batch.Clear(); + } + } } else if (element is IngestionDocumentTable table) { foreach (var cell in table.Cells) { - if (cell is IngestionDocumentImage cellImage) + if (cell is IngestionDocumentImage cellImage && ShouldProcess(cellImage)) { - await ProcessAsync(cellImage, cancellationToken).ConfigureAwait(false); + batch ??= new(_options.BatchSize); + batch.Add(cellImage); + + if (batch.Count == _options.BatchSize) + { + await ProcessAsync(batch, cancellationToken).ConfigureAwait(false); + batch.Clear(); + } } } } } + if (batch?.Count > 0) + { + await ProcessAsync(batch, cancellationToken).ConfigureAwait(false); + } + return document; } - private async Task ProcessAsync(IngestionDocumentImage image, CancellationToken cancellationToken) + private static bool ShouldProcess(IngestionDocumentImage img) => + img.Content.HasValue && !string.IsNullOrEmpty(img.MediaType) && string.IsNullOrEmpty(img.AlternativeText); + + private async Task ProcessAsync(List batch, CancellationToken cancellationToken) { - if (image.Content.HasValue && !string.IsNullOrEmpty(image.MediaType) - && string.IsNullOrEmpty(image.AlternativeText)) + List contents = new(batch.Count); + foreach (var image in batch) { - var response = await _chatClient.GetResponseAsync( - [ - _systemPrompt, - new(ChatRole.User, [new DataContent(image.Content.Value, image.MediaType!)]) - ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + contents.Add(new DataContent(image.Content!.Value, image.MediaType!)); + } + + try + { + ChatResponse response = await _options.ChatClient.GetResponseAsync( + [_systemPrompt, new(ChatRole.User, contents)], + _options.ChatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - image.AlternativeText = response.Text; + if (response.Result.Length == contents.Count) + { + for (int i = 0; i < response.Result.Length; i++) + { + batch[i].AlternativeText = response.Result[i]; + } + } + else + { + _logger?.UnexpectedResultsCount(response.Result.Length, contents.Count); + } + } +#pragma warning disable CA1031 // Do not catch general exception types + catch (Exception ex) +#pragma warning restore CA1031 // Do not catch general exception types + { + // Enricher failures should not fail the whole ingestion pipeline, as they are best-effort enhancements. + _logger?.UnexpectedEnricherFailure(ex); } } } diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs index 56a305e2a87..c12c805544d 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/KeywordEnricher.cs @@ -2,13 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Frozen; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Text; using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.DataIngestion; @@ -22,34 +20,26 @@ namespace Microsoft.Extensions.DataIngestion; public sealed class KeywordEnricher : IngestionChunkProcessor { private const int DefaultMaxKeywords = 5; -#if NET - private static readonly System.Buffers.SearchValues _illegalCharacters = System.Buffers.SearchValues.Create([';', ',']); -#else - private static readonly char[] _illegalCharacters = [';', ',']; -#endif - private readonly IChatClient _chatClient; - private readonly ChatOptions? _chatOptions; - private readonly FrozenSet? _predefinedKeywords; + private readonly EnricherOptions _options; private readonly ChatMessage _systemPrompt; + private readonly ILogger? _logger; /// /// Initializes a new instance of the class. /// - /// The chat client used for keyword extraction. + /// The options for generating keywords. /// The set of predefined keywords for extraction. - /// Options for the chat client. /// The maximum number of keywords to extract. When not provided, it defaults to 5. /// The confidence threshold for keyword inclusion. When not provided, it defaults to 0.7. /// /// If no predefined keywords are provided, the model will extract keywords based on the content alone. /// Such results may vary more significantly between different AI models. /// - public KeywordEnricher(IChatClient chatClient, ReadOnlySpan predefinedKeywords, - ChatOptions? chatOptions = null, int? maxKeywords = null, double? confidenceThreshold = null) + public KeywordEnricher(EnricherOptions options, ReadOnlySpan predefinedKeywords, + int? maxKeywords = null, double? confidenceThreshold = null) { - _chatClient = Throw.IfNull(chatClient); - _chatOptions = chatOptions; - _predefinedKeywords = CreatePredfinedKeywords(predefinedKeywords); + _options = Throw.IfNull(options).Clone(); + Validate(predefinedKeywords); double threshold = confidenceThreshold.HasValue ? Throw.IfOutOfRange(confidenceThreshold.Value, 0.0, 1.0, nameof(confidenceThreshold)) @@ -58,6 +48,7 @@ public KeywordEnricher(IChatClient chatClient, ReadOnlySpan predefinedKe ? Throw.IfLessThanOrEqual(maxKeywords.Value, 0, nameof(maxKeywords)) : DefaultMaxKeywords; _systemPrompt = CreateSystemPrompt(keywordsCount, predefinedKeywords, threshold); + _logger = _options.LoggerFactory?.CreateLogger(); } /// @@ -66,70 +57,29 @@ public KeywordEnricher(IChatClient chatClient, ReadOnlySpan predefinedKe public static string MetadataKey => "keywords"; /// - public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(chunks); - - await foreach (IngestionChunk chunk in chunks.WithCancellation(cancellationToken)) - { - // Structured response is not used here because it's not part of Microsoft.Extensions.AI.Abstractions. - var response = await _chatClient.GetResponseAsync( - [ - _systemPrompt, - new(ChatRole.User, chunk.Content) - ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - -#pragma warning disable EA0009 // Use 'System.MemoryExtensions.Split' for improved performance - string[] keywords = response.Text.Split(';'); - if (_predefinedKeywords is not null) - { - foreach (var keyword in keywords) - { - if (!_predefinedKeywords.Contains(keyword)) - { - throw new InvalidOperationException($"The extracted keyword '{keyword}' is not in the predefined keywords list."); - } - } - } + public override IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, CancellationToken cancellationToken = default) + => Batching.ProcessAsync(chunks, _options, MetadataKey, _systemPrompt, _logger, cancellationToken); - chunk.Metadata[MetadataKey] = keywords; - - yield return chunk; - } - } - - private static FrozenSet? CreatePredfinedKeywords(ReadOnlySpan predefinedKeywords) + private static void Validate(ReadOnlySpan predefinedKeywords) { if (predefinedKeywords.Length == 0) { - return null; + return; } HashSet result = new(StringComparer.Ordinal); foreach (string keyword in predefinedKeywords) { -#if NET - if (keyword.AsSpan().ContainsAny(_illegalCharacters)) -#else - if (keyword.IndexOfAny(_illegalCharacters) >= 0) -#endif - { - Throw.ArgumentException(nameof(predefinedKeywords), $"Predefined keyword '{keyword}' contains an invalid character (';' or ',')."); - } - if (!result.Add(keyword)) { Throw.ArgumentException(nameof(predefinedKeywords), $"Duplicate keyword found: '{keyword}'"); } } - - return result.ToFrozenSet(StringComparer.Ordinal); } private static ChatMessage CreateSystemPrompt(int maxKeywords, ReadOnlySpan predefinedKeywords, double confidenceThreshold) { - StringBuilder sb = new($"You are a keyword extraction expert. Analyze the given text and extract up to {maxKeywords} most relevant keywords. "); + StringBuilder sb = new($"You are a keyword extraction expert. For each of the following texts, extract up to {maxKeywords} most relevant keywords. "); if (predefinedKeywords.Length > 0) { @@ -152,7 +102,6 @@ private static ChatMessage CreateSystemPrompt(int maxKeywords, ReadOnlySpan public sealed class SentimentEnricher : IngestionChunkProcessor { - private readonly IChatClient _chatClient; - private readonly ChatOptions? _chatOptions; - private readonly FrozenSet _validSentiments = -#if NET9_0_OR_GREATER - FrozenSet.Create(StringComparer.Ordinal, "Positive", "Negative", "Neutral", "Unknown"); -#else - new string[] { "Positive", "Negative", "Neutral", "Unknown" }.ToFrozenSet(StringComparer.Ordinal); -#endif + private readonly EnricherOptions _options; private readonly ChatMessage _systemPrompt; + private readonly ILogger? _logger; /// /// Initializes a new instance of the class. /// - /// The chat client used for sentiment analysis. - /// Options for the chat client. + /// The options for sentiment analysis. /// The confidence threshold for sentiment determination. When not provided, it defaults to 0.7. - public SentimentEnricher(IChatClient chatClient, ChatOptions? chatOptions = null, double? confidenceThreshold = null) + public SentimentEnricher(EnricherOptions options, double? confidenceThreshold = null) { - _chatClient = Throw.IfNull(chatClient); - _chatOptions = chatOptions; + _options = Throw.IfNull(options).Clone(); double threshold = confidenceThreshold.HasValue ? Throw.IfOutOfRange(confidenceThreshold.Value, 0.0, 1.0, nameof(confidenceThreshold)) : 0.7; string prompt = $""" - You are a sentiment analysis expert. Analyze the sentiment of the given text and return Positive/Negative/Neutral or - Unknown when confidence score is below {threshold}. Return just the value of the sentiment. + You are a sentiment analysis expert. For each of the following texts, analyze the sentiment and return Positive/Negative/Neutral or + Unknown when confidence score is below {threshold}. """; _systemPrompt = new(ChatRole.System, prompt); + _logger = _options.LoggerFactory?.CreateLogger(); } /// @@ -56,27 +47,6 @@ Unknown when confidence score is below {threshold}. Return just the value of the public static string MetadataKey => "sentiment"; /// - public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(chunks); - - await foreach (var chunk in chunks.WithCancellation(cancellationToken)) - { - var response = await _chatClient.GetResponseAsync( - [ - _systemPrompt, - new(ChatRole.User, chunk.Content) - ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - - if (!_validSentiments.Contains(response.Text)) - { - throw new InvalidOperationException($"Invalid sentiment response: '{response.Text}'."); - } - - chunk.Metadata[MetadataKey] = response.Text; - - yield return chunk; - } - } + public override IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, CancellationToken cancellationToken = default) + => Batching.ProcessAsync(chunks, _options, MetadataKey, _systemPrompt, _logger, cancellationToken); } diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs index f91b9809b05..7e2da4d12f5 100644 --- a/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Processors/SummaryEnricher.cs @@ -3,10 +3,9 @@ using System; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.DataIngestion; @@ -19,23 +18,22 @@ namespace Microsoft.Extensions.DataIngestion; /// public sealed class SummaryEnricher : IngestionChunkProcessor { - private readonly IChatClient _chatClient; - private readonly ChatOptions? _chatOptions; + private readonly EnricherOptions _options; private readonly ChatMessage _systemPrompt; + private readonly ILogger? _logger; /// /// Initializes a new instance of the class. /// - /// The chat client used for summary generation. - /// Options for the chat client. + /// The options for summary generation. /// The maximum number of words for the summary. When not provided, it defaults to 100. - public SummaryEnricher(IChatClient chatClient, ChatOptions? chatOptions = null, int? maxWordCount = null) + public SummaryEnricher(EnricherOptions options, int? maxWordCount = null) { - _chatClient = Throw.IfNull(chatClient); - _chatOptions = chatOptions; + _options = Throw.IfNull(options).Clone(); int wordCount = maxWordCount.HasValue ? Throw.IfLessThanOrEqual(maxWordCount.Value, 0, nameof(maxWordCount)) : 100; - _systemPrompt = new(ChatRole.System, $"Write a summary text for this text with no more than {wordCount} words. Return just the summary."); + _systemPrompt = new(ChatRole.System, $"For each of the following texts, write a summary text with no more than {wordCount} words."); + _logger = _options.LoggerFactory?.CreateLogger(); } /// @@ -44,22 +42,6 @@ public SummaryEnricher(IChatClient chatClient, ChatOptions? chatOptions = null, public static string MetadataKey => "summary"; /// - public override async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(chunks); - - await foreach (var chunk in chunks.WithCancellation(cancellationToken)) - { - var response = await _chatClient.GetResponseAsync( - [ - _systemPrompt, - new(ChatRole.User, chunk.Content) - ], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - - chunk.Metadata[MetadataKey] = response.Text; - - yield return chunk; - } - } + public override IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, CancellationToken cancellationToken = default) + => Batching.ProcessAsync(chunks, _options, MetadataKey, _systemPrompt, _logger, cancellationToken); } diff --git a/src/Libraries/Microsoft.Extensions.DataIngestion/Utils/Batching.cs b/src/Libraries/Microsoft.Extensions.DataIngestion/Utils/Batching.cs new file mode 100644 index 00000000000..b210019401b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.DataIngestion/Utils/Batching.cs @@ -0,0 +1,108 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +#if NET10_0_OR_GREATER +using System.Linq; +#endif +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.DataIngestion; + +internal static class Batching +{ + internal static async IAsyncEnumerable> ProcessAsync(IAsyncEnumerable> chunks, + EnricherOptions options, + string metadataKey, + ChatMessage systemPrompt, + ILogger? logger, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TMetadata : notnull + { + _ = Throw.IfNull(chunks); + + await foreach (var batch in chunks.Chunk(options.BatchSize).WithCancellation(cancellationToken)) + { + List contents = new(batch.Length); + foreach (var chunk in batch) + { + contents.Add(new TextContent(chunk.Content)); + } + + try + { + ChatResponse response = await options.ChatClient.GetResponseAsync( + [ + systemPrompt, + new(ChatRole.User, contents) + ], options.ChatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + + if (response.Result.Length == contents.Count) + { + for (int i = 0; i < response.Result.Length; i++) + { + batch[i].Metadata[metadataKey] = response.Result[i]; + } + } + else + { + logger?.UnexpectedResultsCount(response.Result.Length, contents.Count); + } + } +#pragma warning disable CA1031 // Do not catch general exception types + catch (Exception ex) +#pragma warning restore CA1031 // Do not catch general exception types + { + // Enricher failures should not fail the whole ingestion pipeline, as they are best-effort enhancements. + logger?.UnexpectedEnricherFailure(ex); + } + + foreach (var chunk in batch) + { + yield return chunk; + } + } + } + +#if !NET10_0_OR_GREATER +#pragma warning disable VSTHRD200 // Use "Async" suffix for async methods + private static IAsyncEnumerable Chunk(this IAsyncEnumerable source, int count) +#pragma warning restore VSTHRD200 // Use "Async" suffix for async methods + { + _ = Throw.IfNull(source); + _ = Throw.IfLessThanOrEqual(count, 0); + + return CoreAsync(source, count); + + static async IAsyncEnumerable CoreAsync(IAsyncEnumerable source, int count, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var buffer = new TSource[count]; + int index = 0; + + await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + buffer[index++] = item; + + if (index == count) + { + index = 0; + yield return buffer; + } + } + + if (index > 0) + { + Array.Resize(ref buffer, index); + yield return buffer; + } + } + } +#endif +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/IngestionPipelineTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/IngestionPipelineTests.cs index e2e25aa4664..f2f0d85c458 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/IngestionPipelineTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/IngestionPipelineTests.cs @@ -211,7 +211,7 @@ async Task Verify(IAsyncEnumerable results) List ingestionResults = await results.ToListAsync(); Assert.Equal(_sampleFiles.Count, ingestionResults.Count); - Assert.All(ingestionResults, result => Assert.NotNull(result.Source)); + Assert.All(ingestionResults, result => Assert.NotEmpty(result.DocumentId)); IngestionResult ingestionResult = Assert.Single(ingestionResults.Where(result => !result.Succeeded)); Assert.IsType(ingestionResult.Exception); AssertErrorActivities(activities, expectedFailedActivitiesCount: 1); @@ -221,16 +221,6 @@ async Task Verify(IAsyncEnumerable results) } } - private class ExpectedException : Exception - { - internal const string ExceptionMessage = "An expected exception occurred."; - - public ExpectedException() - : base(ExceptionMessage) - { - } - } - private static IngestionDocumentReader CreateReader() => new MarkdownReader(); private static IngestionChunker CreateChunker() => new HeaderChunker(new(TiktokenTokenizer.CreateForModel("gpt-4"))); @@ -246,7 +236,7 @@ private static void AssertAllIngestionsSucceeded(List ingestion { Assert.NotEmpty(ingestionResults); Assert.All(ingestionResults, result => Assert.True(result.Succeeded)); - Assert.All(ingestionResults, result => Assert.NotNull(result.Source)); + Assert.All(ingestionResults, result => Assert.NotEmpty(result.DocumentId)); Assert.All(ingestionResults, result => Assert.NotNull(result.Document)); Assert.All(ingestionResults, result => Assert.Null(result.Exception)); } diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj index bb2a082b875..d5864d426dc 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Microsoft.Extensions.DataIngestion.Tests.csproj @@ -14,6 +14,7 @@ + diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs index cc59db3f389..6dad7e4af0a 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/AlternativeTextEnricherTests.cs @@ -3,18 +3,23 @@ using System; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.Extensions.DataIngestion.Processors.Tests; public class AlternativeTextEnricherTests { + private readonly ReadOnlyMemory _imageContent = new byte[256]; + [Fact] - public void ThrowsOnNullChatClient() + public void ThrowsOnNullOptions() { - Assert.Throws("chatClient", () => new ImageAlternativeTextEnricher(null!)); + Assert.Throws("options", () => new ImageAlternativeTextEnricher(null!)); } [Fact] @@ -22,7 +27,7 @@ public async Task ThrowsOnNullDocument() { using TestChatClient chatClient = new(); - ImageAlternativeTextEnricher sut = new(chatClient); + ImageAlternativeTextEnricher sut = new(new(chatClient)); await Assert.ThrowsAsync("document", async () => await sut.ProcessAsync(null!)); } @@ -31,9 +36,7 @@ public async Task ThrowsOnNullDocument() public async Task CanGenerateImageAltText() { const string PreExistingAltText = "Pre-existing alt text"; - ReadOnlyMemory imageContent = new byte[256]; - int counter = 0; string[] descriptions = { "First alt text", "Second alt text" }; using TestChatClient chatClient = new() { @@ -44,37 +47,41 @@ public async Task CanGenerateImageAltText() Assert.Equal(2, materializedMessages.Length); Assert.Equal(ChatRole.System, materializedMessages[0].Role); Assert.Equal(ChatRole.User, materializedMessages[1].Role); - var content = Assert.Single(materializedMessages[1].Contents); - DataContent dataContent = Assert.IsType(content); - Assert.Equal("image/png", dataContent.MediaType); - Assert.Equal(imageContent.ToArray(), dataContent.Data.ToArray()); + Assert.Equal(2, materializedMessages[1].Contents.Count); + + Assert.All(materializedMessages[1].Contents, content => + { + DataContent dataContent = Assert.IsType(content); + Assert.Equal("image/png", dataContent.MediaType); + Assert.Equal(_imageContent.ToArray(), dataContent.Data.ToArray()); + }); return Task.FromResult(new ChatResponse(new[] { - new ChatMessage(ChatRole.Assistant, descriptions[counter++]) + new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(new Envelope { data = descriptions })) })); } }; - ImageAlternativeTextEnricher sut = new(chatClient); + ImageAlternativeTextEnricher sut = new(new(chatClient)); IngestionDocumentImage documentImage = new($"![](nonExisting.png)") { AlternativeText = null, - Content = imageContent, + Content = _imageContent, MediaType = "image/png" }; IngestionDocumentImage tableCell = new($"![](another.png)") { AlternativeText = null, - Content = imageContent, + Content = _imageContent, MediaType = "image/png" }; IngestionDocumentImage imageWithAltText = new($"![](noChangesNeeded.png)") { AlternativeText = PreExistingAltText, - Content = imageContent, + Content = _imageContent, MediaType = "image/png" }; @@ -107,4 +114,91 @@ public async Task CanGenerateImageAltText() Assert.Same(PreExistingAltText, imageWithAltText.AlternativeText); Assert.Null(imageWithNoContent.AlternativeText); } + + [Theory] + [InlineData(1, 3)] + [InlineData(3, 7)] + [InlineData(15, 3)] + public async Task SendsOneRequestPerBatchSize(int batchSize, int batchCount) + { + int callsCount = 0; + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + callsCount++; + + var materializedMessages = messages.ToArray(); + + // One system message + one User message with all the contents + Assert.Equal(2, materializedMessages.Length); + Assert.Equal(ChatRole.System, materializedMessages[0].Role); + Assert.Equal(ChatRole.User, materializedMessages[1].Role); + Assert.Equal(batchSize, materializedMessages[1].Contents.Count); + + Assert.All(materializedMessages[1].Contents, content => + { + DataContent dataContent = Assert.IsType(content); + Assert.Equal("image/png", dataContent.MediaType); + Assert.Equal(_imageContent.ToArray(), dataContent.Data.ToArray()); + }); + + Envelope data = new() { data = Enumerable.Range(0, batchSize).Select(i => i.ToString()).ToArray() }; + return Task.FromResult(new ChatResponse(new[] { new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(data)) })); + } + }; + + ImageAlternativeTextEnricher sut = new(new(chatClient) { BatchSize = batchSize }); + + IngestionDocument document = CreateDocument(batchSize, batchCount, _imageContent); + + await sut.ProcessAsync(document); + Assert.Equal(batchCount, callsCount); + } + + [Fact] + public async Task FailureDoesNotStopTheProcessing() + { + FakeLogCollector collector = new(); + using ILoggerFactory loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector))); + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromException(new ExpectedException()) + }; + + EnricherOptions options = new(chatClient) { LoggerFactory = loggerFactory }; + ImageAlternativeTextEnricher sut = new(options); + + const int BatchCount = 2; + IngestionDocument document = CreateDocument(options.BatchSize, BatchCount, _imageContent); + IngestionDocument got = await sut.ProcessAsync(document); + + Assert.Equal(BatchCount, collector.Count); + Assert.All(collector.GetSnapshot(), record => + { + Assert.Equal(LogLevel.Error, record.Level); + Assert.IsType(record.Exception); + }); + } + + private static IngestionDocument CreateDocument(int batchSize, int batchCount, ReadOnlyMemory imageContent) + { + IngestionDocumentSection rootSection = new(); + for (int i = 0; i < batchSize * batchCount; i++) + { + IngestionDocumentImage image = new($"![](image{i}.png)") + { + Content = imageContent, + MediaType = "image/png", + AlternativeText = null + }; + + rootSection.Elements.Add(image); + } + + return new("batchTest") + { + Sections = { rootSection } + }; + } } diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs index 3f890969262..15d0a5f6152 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/ClassificationEnricherTests.cs @@ -4,8 +4,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.Extensions.DataIngestion.Processors.Tests; @@ -15,46 +18,40 @@ public class ClassificationEnricherTests private static readonly IngestionDocument _document = new("test"); [Fact] - public void ThrowsOnNullChatClient() + public void ThrowsOnNullOptions() { - Assert.Throws("chatClient", () => new ClassificationEnricher(null!, predefinedClasses: ["some"])); + Assert.Throws("options", () => new ClassificationEnricher(null!, predefinedClasses: ["some"])); } [Fact] public void ThrowsOnEmptyPredefinedClasses() { - Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: [])); + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new(new TestChatClient()), predefinedClasses: [])); } [Fact] public void ThrowsOnDuplicatePredefinedClasses() { - Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["same", "same"])); + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new(new TestChatClient()), predefinedClasses: ["same", "same"])); } [Fact] public void ThrowsOnPredefinedClassesContainingFallback() { - Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["same", "Unknown"])); + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new(new TestChatClient()), predefinedClasses: ["same", "Unknown"])); } [Fact] public void ThrowsOnFallbackInPredefinedClasses() { - Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["some"], fallbackClass: "some")); - } - - [Fact] - public void ThrowsOnPredefinedClassesContainingComma() - { - Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new TestChatClient(), predefinedClasses: ["n,t"])); + Assert.Throws("predefinedClasses", () => new ClassificationEnricher(new(new TestChatClient()), predefinedClasses: ["some"], fallbackClass: "some")); } [Fact] public async Task ThrowsOnNullChunks() { using TestChatClient chatClient = new(); - ClassificationEnricher sut = new(chatClient, predefinedClasses: ["some"]); + ClassificationEnricher sut = new(new(chatClient), predefinedClasses: ["some"]); await Assert.ThrowsAsync("chunks", async () => { @@ -74,19 +71,21 @@ public async Task CanClassify() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { + Assert.Equal(0, counter++); var materializedMessages = messages.ToArray(); Assert.Equal(2, materializedMessages.Length); Assert.Equal(ChatRole.System, materializedMessages[0].Role); Assert.Equal(ChatRole.User, materializedMessages[1].Role); + string response = JsonSerializer.Serialize(new Envelope { data = classes }); return Task.FromResult(new ChatResponse(new[] { - new ChatMessage(ChatRole.Assistant, classes[counter++]) + new ChatMessage(ChatRole.Assistant, response) })); } }; - ClassificationEnricher sut = new(chatClient, ["AI", "Animals", "Sports"], fallbackClass: "UFO"); + ClassificationEnricher sut = new(new(chatClient), ["AI", "Animals", "Sports"], fallbackClass: "UFO"); IReadOnlyList> got = await sut.ProcessAsync(CreateChunks().ToAsyncEnumerable()).ToListAsync(); @@ -97,29 +96,25 @@ public async Task CanClassify() } [Fact] - public async Task ThrowsOnInvalidResponse() + public async Task FailureDoesNotStopTheProcessing() { + FakeLogCollector collector = new(); + using ILoggerFactory loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector))); using TestChatClient chatClient = new() { - GetResponseAsyncCallback = (messages, options, cancellationToken) => - { - return Task.FromResult(new ChatResponse(new[] - { - new ChatMessage(ChatRole.Assistant, "Unexpected result!") - })); - } + GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromException(new ExpectedException()) }; - ClassificationEnricher sut = new(chatClient, ["AI", "Animals", "Sports"]); - var input = CreateChunks().ToAsyncEnumerable(); + ClassificationEnricher sut = new(new(chatClient) { LoggerFactory = loggerFactory }, ["AI", "Other"]); + List> chunks = CreateChunks(); - await Assert.ThrowsAsync(async () => - { - await foreach (var _ in sut.ProcessAsync(input)) - { - // No-op - } - }); + IReadOnlyList> got = await sut.ProcessAsync(chunks.ToAsyncEnumerable()).ToListAsync(); + + Assert.Equal(chunks.Count, got.Count); + Assert.All(chunks, chunk => Assert.False(chunk.HasMetadata)); + Assert.Equal(1, collector.Count); // with batching, only one log entry is expected + Assert.Equal(LogLevel.Error, collector.LatestRecord.Level); + Assert.IsType(collector.LatestRecord.Exception); } private static List> CreateChunks() => diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs index 0f11cd7d46b..5a116e1ab04 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/KeywordEnricherTests.cs @@ -4,8 +4,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.Extensions.DataIngestion.Processors.Tests; @@ -15,9 +18,9 @@ public class KeywordEnricherTests private static readonly IngestionDocument _document = new("test"); [Fact] - public void ThrowsOnNullChatClient() + public void ThrowsOnNullOptions() { - Assert.Throws("chatClient", () => new KeywordEnricher(null!, predefinedKeywords: null, confidenceThreshold: 0.5)); + Assert.Throws("options", () => new KeywordEnricher(null!, predefinedKeywords: null, confidenceThreshold: 0.5)); } [Theory] @@ -25,7 +28,7 @@ public void ThrowsOnNullChatClient() [InlineData(1.1)] public void ThrowsOnInvalidThreshold(double threshold) { - Assert.Throws("confidenceThreshold", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: null, confidenceThreshold: threshold)); + Assert.Throws("confidenceThreshold", () => new KeywordEnricher(new(new TestChatClient()), predefinedKeywords: null, confidenceThreshold: threshold)); } [Theory] @@ -33,28 +36,20 @@ public void ThrowsOnInvalidThreshold(double threshold) [InlineData(-1)] public void ThrowsOnInvalidMaxKeywords(int keywordCount) { - Assert.Throws("maxKeywords", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: null, maxKeywords: keywordCount)); + Assert.Throws("maxKeywords", () => new KeywordEnricher(new(new TestChatClient()), predefinedKeywords: null, maxKeywords: keywordCount)); } [Fact] public void ThrowsOnDuplicateKeywords() { - Assert.Throws("predefinedKeywords", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: ["same", "same"], confidenceThreshold: 0.5)); - } - - [Theory] - [InlineData(',')] - [InlineData(';')] - public void ThrowsOnIllegalCharacters(char illegal) - { - Assert.Throws("predefinedKeywords", () => new KeywordEnricher(new TestChatClient(), predefinedKeywords: [$"n{illegal}t"])); + Assert.Throws("predefinedKeywords", () => new KeywordEnricher(new(new TestChatClient()), predefinedKeywords: ["same", "same"], confidenceThreshold: 0.5)); } [Fact] public async Task ThrowsOnNullChunks() { using TestChatClient chatClient = new(); - KeywordEnricher sut = new(chatClient, predefinedKeywords: null, confidenceThreshold: 0.5); + KeywordEnricher sut = new(new(chatClient), predefinedKeywords: null, confidenceThreshold: 0.5); await Assert.ThrowsAsync("chunks", async () => { @@ -71,25 +66,27 @@ await Assert.ThrowsAsync("chunks", async () => public async Task CanExtractKeywords(params string[] predefined) { int counter = 0; - string[] keywords = { "AI;MEAI", "Animals;Rabbits" }; + string[][] keywords = [["AI", "MEAI"], ["Animals", "Rabbits"]]; using TestChatClient chatClient = new() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { + Assert.Equal(0, counter++); var materializedMessages = messages.ToArray(); Assert.Equal(2, materializedMessages.Length); Assert.Equal(ChatRole.System, materializedMessages[0].Role); Assert.Equal(ChatRole.User, materializedMessages[1].Role); + string response = JsonSerializer.Serialize(new Envelope { data = keywords }); return Task.FromResult(new ChatResponse(new[] { - new ChatMessage(ChatRole.Assistant, keywords[counter++]) + new ChatMessage(ChatRole.Assistant, response) })); } }; - KeywordEnricher sut = new(chatClient, predefinedKeywords: predefined, confidenceThreshold: 0.5); + KeywordEnricher sut = new(new(chatClient), predefinedKeywords: predefined, confidenceThreshold: 0.5); var chunks = CreateChunks().ToAsyncEnumerable(); IReadOnlyList> got = await sut.ProcessAsync(chunks).ToListAsync(); @@ -99,29 +96,25 @@ public async Task CanExtractKeywords(params string[] predefined) } [Fact] - public async Task ThrowsOnInvalidResponse() + public async Task FailureDoesNotStopTheProcessing() { + FakeLogCollector collector = new(); + using ILoggerFactory loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector))); using TestChatClient chatClient = new() { - GetResponseAsyncCallback = (messages, options, cancellationToken) => - { - return Task.FromResult(new ChatResponse(new[] - { - new ChatMessage(ChatRole.Assistant, "Unexpected result!") - })); - } + GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromException(new ExpectedException()) }; - KeywordEnricher sut = new(chatClient, ["some"]); - var input = CreateChunks().ToAsyncEnumerable(); + KeywordEnricher sut = new(new(chatClient) { LoggerFactory = loggerFactory }, ["AI", "Other"]); + List> chunks = CreateChunks(); - await Assert.ThrowsAsync(async () => - { - await foreach (var _ in sut.ProcessAsync(input)) - { - // No-op - } - }); + IReadOnlyList> got = await sut.ProcessAsync(chunks.ToAsyncEnumerable()).ToListAsync(); + + Assert.Equal(chunks.Count, got.Count); + Assert.All(chunks, chunk => Assert.False(chunk.HasMetadata)); + Assert.Equal(1, collector.Count); // with batching, only one log entry is expected + Assert.Equal(LogLevel.Error, collector.LatestRecord.Level); + Assert.IsType(collector.LatestRecord.Exception); } private static List> CreateChunks() => diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs index 166b3c05959..8d762f3199c 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SentimentEnricherTests.cs @@ -4,8 +4,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.Extensions.DataIngestion.Processors.Tests; @@ -15,9 +18,9 @@ public class SentimentEnricherTests private static readonly IngestionDocument _document = new("test"); [Fact] - public void ThrowsOnNullChatClient() + public void ThrowsOnNullOptions() { - Assert.Throws("chatClient", () => new SentimentEnricher(null!)); + Assert.Throws("options", () => new SentimentEnricher(null!)); } [Theory] @@ -25,14 +28,14 @@ public void ThrowsOnNullChatClient() [InlineData(1.1)] public void ThrowsOnInvalidThreshold(double threshold) { - Assert.Throws("confidenceThreshold", () => new SentimentEnricher(new TestChatClient(), confidenceThreshold: threshold)); + Assert.Throws("confidenceThreshold", () => new SentimentEnricher(new(new TestChatClient()), confidenceThreshold: threshold)); } [Fact] public async Task ThrowsOnNullChunks() { using TestChatClient chatClient = new(); - SentimentEnricher sut = new(chatClient); + SentimentEnricher sut = new(new(chatClient)); await Assert.ThrowsAsync("chunks", async () => { @@ -52,19 +55,21 @@ public async Task CanProvideSentiment() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { + Assert.Equal(0, counter++); var materializedMessages = messages.ToArray(); Assert.Equal(2, materializedMessages.Length); Assert.Equal(ChatRole.System, materializedMessages[0].Role); Assert.Equal(ChatRole.User, materializedMessages[1].Role); + string response = JsonSerializer.Serialize(new Envelope { data = sentiments }); return Task.FromResult(new ChatResponse(new[] { - new ChatMessage(ChatRole.Assistant, sentiments[counter++]) + new ChatMessage(ChatRole.Assistant, response) })); } }; - SentimentEnricher sut = new(chatClient); + SentimentEnricher sut = new(new(chatClient)); var input = CreateChunks().ToAsyncEnumerable(); var chunks = await sut.ProcessAsync(input).ToListAsync(); @@ -78,29 +83,25 @@ public async Task CanProvideSentiment() } [Fact] - public async Task ThrowsOnInvalidResponse() + public async Task FailureDoesNotStopTheProcessing() { + FakeLogCollector collector = new(); + using ILoggerFactory loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector))); using TestChatClient chatClient = new() { - GetResponseAsyncCallback = (messages, options, cancellationToken) => - { - return Task.FromResult(new ChatResponse(new[] - { - new ChatMessage(ChatRole.Assistant, "Unexpected result!") - })); - } + GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromException(new ExpectedException()) }; - SentimentEnricher sut = new(chatClient); - var input = CreateChunks().ToAsyncEnumerable(); + SentimentEnricher sut = new(new(chatClient) { LoggerFactory = loggerFactory }); + List> chunks = CreateChunks(); - await Assert.ThrowsAsync(async () => - { - await foreach (var _ in sut.ProcessAsync(input)) - { - // No-op - } - }); + IReadOnlyList> got = await sut.ProcessAsync(chunks.ToAsyncEnumerable()).ToListAsync(); + + Assert.Equal(chunks.Count, got.Count); + Assert.All(chunks, chunk => Assert.False(chunk.HasMetadata)); + Assert.Equal(1, collector.Count); // with batching, only one log entry is expected + Assert.Equal(LogLevel.Error, collector.LatestRecord.Level); + Assert.IsType(collector.LatestRecord.Exception); } private static List> CreateChunks() => diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs index 6fda37004d3..8b0dcd904c4 100644 --- a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Processors/SummaryEnricherTests.cs @@ -4,8 +4,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.Extensions.DataIngestion.Processors.Tests; @@ -15,9 +18,9 @@ public class SummaryEnricherTests private static readonly IngestionDocument _document = new("test"); [Fact] - public void ThrowsOnNullChatClient() + public void ThrowsOnNullOptions() { - Assert.Throws("chatClient", () => new SummaryEnricher(null!)); + Assert.Throws("options", () => new SummaryEnricher(null!)); } [Theory] @@ -25,14 +28,14 @@ public void ThrowsOnNullChatClient() [InlineData(-1)] public void ThrowsOnInvalidMaxKeywords(int wordCount) { - Assert.Throws("maxWordCount", () => new SummaryEnricher(new TestChatClient(), maxWordCount: wordCount)); + Assert.Throws("maxWordCount", () => new SummaryEnricher(new(new TestChatClient()), maxWordCount: wordCount)); } [Fact] public async Task ThrowsOnNullChunks() { using TestChatClient chatClient = new(); - SummaryEnricher sut = new(chatClient); + SummaryEnricher sut = new(new(chatClient)); await Assert.ThrowsAsync("chunks", async () => { @@ -52,19 +55,21 @@ public async Task CanProvideSummary() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { + Assert.Equal(0, counter++); var materializedMessages = messages.ToArray(); Assert.Equal(2, materializedMessages.Length); Assert.Equal(ChatRole.System, materializedMessages[0].Role); Assert.Equal(ChatRole.User, materializedMessages[1].Role); + string response = JsonSerializer.Serialize(new Envelope { data = summaries }); return Task.FromResult(new ChatResponse(new[] { - new ChatMessage(ChatRole.Assistant, summaries[counter++]) + new ChatMessage(ChatRole.Assistant, response) })); } }; - SummaryEnricher sut = new(chatClient); + SummaryEnricher sut = new(new(chatClient)); var input = CreateChunks().ToAsyncEnumerable(); var chunks = await sut.ProcessAsync(input).ToListAsync(); @@ -74,6 +79,28 @@ public async Task CanProvideSummary() Assert.Equal(summaries[1], (string)chunks[1].Metadata[SummaryEnricher.MetadataKey]!); } + [Fact] + public async Task FailureDoesNotStopTheProcessing() + { + FakeLogCollector collector = new(); + using ILoggerFactory loggerFactory = LoggerFactory.Create(b => b.AddProvider(new FakeLoggerProvider(collector))); + using TestChatClient chatClient = new() + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromException(new ExpectedException()) + }; + + SummaryEnricher sut = new(new(chatClient) { LoggerFactory = loggerFactory }); + List> chunks = CreateChunks(); + + IReadOnlyList> got = await sut.ProcessAsync(chunks.ToAsyncEnumerable()).ToListAsync(); + + Assert.Equal(chunks.Count, got.Count); + Assert.All(chunks, chunk => Assert.False(chunk.HasMetadata)); + Assert.Equal(1, collector.Count); // with batching, only one log entry is expected + Assert.Equal(LogLevel.Error, collector.LatestRecord.Level); + Assert.IsType(collector.LatestRecord.Exception); + } + private static List> CreateChunks() => [ new("I love programming! It's so much fun and rewarding.", _document), diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/Envelope{T}.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/Envelope{T}.cs new file mode 100644 index 00000000000..d6fade6892f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/Envelope{T}.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.DataIngestion; + +internal class Envelope +{ +#pragma warning disable IDE1006 // Naming Styles + public T? data { get; set; } +#pragma warning restore IDE1006 // Naming Styles +} diff --git a/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/ExpectedException.cs b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/ExpectedException.cs new file mode 100644 index 00000000000..79d2e7538fd --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.DataIngestion.Tests/Utils/ExpectedException.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.DataIngestion; + +internal sealed class ExpectedException : Exception +{ + internal const string ExceptionMessage = "An expected exception occurred."; + + internal ExpectedException() + : base(ExceptionMessage) + { + } +}