Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

<ItemGroup>
<ProjectReference Include="..\Microsoft.Extensions.DataIngestion.Abstractions\Microsoft.Extensions.DataIngestion.Abstractions.csproj" />
<ProjectReference Include="..\Microsoft.Extensions.AI\Microsoft.Extensions.AI.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="System.Collections.Immutable" Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'" />
<PackageReference Include="Microsoft.Extensions.VectorData.Abstractions" />
<PackageReference Include="Microsoft.ML.Tokenizers" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Frozen;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
Expand All @@ -21,29 +20,25 @@ namespace Microsoft.Extensions.DataIngestion;
/// an optional fallback class for cases where no suitable classification can be determined.</remarks>
public sealed class ClassificationEnricher : IngestionChunkProcessor<string>
{
private readonly IChatClient _chatClient;
private readonly ChatOptions? _chatOptions;
private readonly FrozenSet<string> _predefinedClasses;
private readonly EnricherOptions _options;
private readonly ChatMessage _systemPrompt;

/// <summary>
/// Initializes a new instance of the <see cref="ClassificationEnricher"/> class.
/// </summary>
/// <param name="chatClient">The chat client used for classification.</param>
/// <param name="options">The options for the classification enricher.</param>
/// <param name="predefinedClasses">The set of predefined classification classes.</param>
/// <param name="chatOptions">Options for the chat client.</param>
/// <param name="fallbackClass">The fallback class to use when no suitable classification is found. When not provided, it defaults to "Unknown".</param>
public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan<string> predefinedClasses,
ChatOptions? chatOptions = null, string? fallbackClass = null)
public ClassificationEnricher(EnricherOptions options, ReadOnlySpan<string> predefinedClasses,
string? fallbackClass = null)
{
_chatClient = Throw.IfNull(chatClient);
_chatOptions = chatOptions;
_options = Throw.IfNull(options).Clone();
if (string.IsNullOrWhiteSpace(fallbackClass))
{
fallbackClass = "Unknown";
}

_predefinedClasses = CreatePredefinedSet(predefinedClasses, fallbackClass!);
Validate(predefinedClasses, fallbackClass!);
_systemPrompt = CreateSystemPrompt(predefinedClasses, fallbackClass!);
}

Expand All @@ -58,23 +53,34 @@ public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IAsy
{
_ = Throw.IfNull(chunks);

await foreach (IngestionChunk<string> chunk in chunks.WithCancellation(cancellationToken))
await foreach (var batch in chunks.BufferAsync(_options.BatchSize).WithCancellation(cancellationToken))
{
var response = await _chatClient.GetResponseAsync(
List<AIContent> contents = new(batch.Count);
foreach (var chunk in batch)
{
contents.Add(new TextContent(chunk.Content));
}

var response = await _options.ChatClient.GetResponseAsync<string[]>(
[
_systemPrompt,
new(ChatRole.User, chunk.Content)
], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
new(ChatRole.User, contents)
], _options.ChatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);

chunk.Metadata[MetadataKey] = _predefinedClasses.Contains(response.Text)
? response.Text
: throw new InvalidOperationException($"Classification returned an unexpected class: '{response.Text}'.");
if (response.Result.Length != contents.Count)
{
throw new InvalidOperationException($"The AI chat service returned {response.Result.Length} instead of {contents.Count} results.");
}

yield return chunk;
for (int i = 0; i < response.Result.Length; i++)
{
batch[i].Metadata[MetadataKey] = response.Result[i];
yield return batch[i];
}
}
}

private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
private static void Validate(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
{
if (predefinedClasses.Length == 0)
{
Expand All @@ -84,15 +90,6 @@ private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predef
HashSet<string> predefinedClassesSet = new(StringComparer.Ordinal) { fallbackClass };
foreach (string predefinedClass in predefinedClasses)
{
#if NET
if (predefinedClass.Contains(',', StringComparison.Ordinal))
#else
if (predefinedClass.IndexOf(',') >= 0)
#endif
{
Throw.ArgumentException(nameof(predefinedClasses), $"Predefined class '{predefinedClass}' must not contain ',' character.");
}

if (!predefinedClassesSet.Add(predefinedClass))
{
if (predefinedClass.Equals(fallbackClass, StringComparison.Ordinal))
Expand All @@ -103,13 +100,11 @@ private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predef
Throw.ArgumentException(nameof(predefinedClasses), $"Duplicate class found: '{predefinedClass}'.");
}
}

return predefinedClassesSet.ToFrozenSet();
}

private static ChatMessage CreateSystemPrompt(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
{
StringBuilder sb = new("You are a classification expert. Analyze the given text and assign a single, most relevant class. Use only the following predefined classes: ");
StringBuilder sb = new("You are a classification expert. For each of the following texts, assign a single, most relevant class. Use only the following predefined classes: ");

#if NET9_0_OR_GREATER
sb.AppendJoin(", ", predefinedClasses!);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.DataIngestion;

/// <summary>
/// Represents options for enrichers that use an AI chat client.
/// </summary>
public class EnricherOptions
{
/// <summary>
/// Initializes a new instance of the <see cref="EnricherOptions"/> class.
/// </summary>
/// <param name="chatClient">The AI chat client to be used.</param>
public EnricherOptions(IChatClient chatClient)
{
ChatClient = Throw.IfNull(chatClient);
}

/// <summary>
/// Gets the AI chat client to be used.
/// </summary>
public IChatClient ChatClient { get; }

/// <summary>
/// Gets or sets the options for the <see cref="ChatClient"/>.
/// </summary>
public ChatOptions? ChatOptions { get; set; }

/// <summary>
/// Gets or sets the batch size for processing chunks. Default is 20.
/// </summary>
public int BatchSize { get; set => field = Throw.IfLessThanOrEqual(value, 0); } = 20;

internal EnricherOptions Clone() => new(ChatClient)
{
ChatOptions = ChatOptions?.Clone(),
BatchSize = BatchSize
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand All @@ -15,60 +16,93 @@ namespace Microsoft.Extensions.DataIngestion;
/// </summary>
public sealed class ImageAlternativeTextEnricher : IngestionDocumentProcessor
{
private readonly IChatClient _chatClient;
private readonly ChatOptions? _chatOptions;
private readonly EnricherOptions _options;
private readonly ChatMessage _systemPrompt;

/// <summary>
/// Initializes a new instance of the <see cref="ImageAlternativeTextEnricher"/> class.
/// </summary>
/// <param name="chatClient">The chat client used to get responses for generating alternative text.</param>
/// <param name="chatOptions">Options for the chat client.</param>
public ImageAlternativeTextEnricher(IChatClient chatClient, ChatOptions? chatOptions = null)
/// <param name="options">The options for generating alternative text.</param>
public ImageAlternativeTextEnricher(EnricherOptions options)
{
_chatClient = Throw.IfNull(chatClient);
_chatOptions = chatOptions;
_systemPrompt = new(ChatRole.System, "Write a detailed alternative text for this image with less than 50 words.");
_options = Throw.IfNull(options).Clone();
_systemPrompt = new(ChatRole.System, "For each of the following images, write a detailed alternative text with fewer than 50 words.");
}

/// <inheritdoc/>
public override async Task<IngestionDocument> ProcessAsync(IngestionDocument document, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(document);

List<IngestionDocumentImage>? batch = null;

foreach (var element in document.EnumerateContent())
{
if (element is IngestionDocumentImage image)
{
await ProcessAsync(image, cancellationToken).ConfigureAwait(false);
if (ShouldProcess(image))
{
batch ??= new(_options.BatchSize);
batch.Add(image);

if (batch.Count == _options.BatchSize)
{
await ProcessAsync(batch, cancellationToken).ConfigureAwait(false);
batch.Clear();
}
}
}
else if (element is IngestionDocumentTable table)
{
foreach (var cell in table.Cells)
{
if (cell is IngestionDocumentImage cellImage)
if (cell is IngestionDocumentImage cellImage && ShouldProcess(cellImage))
{
await ProcessAsync(cellImage, cancellationToken).ConfigureAwait(false);
batch ??= new(_options.BatchSize);
batch.Add(cellImage);

if (batch.Count == _options.BatchSize)
{
await ProcessAsync(batch, cancellationToken).ConfigureAwait(false);
batch.Clear();
}
}
}
}
}

if (batch?.Count > 0)
{
await ProcessAsync(batch, cancellationToken).ConfigureAwait(false);
}

return document;
}

private async Task ProcessAsync(IngestionDocumentImage image, CancellationToken cancellationToken)
private static bool ShouldProcess(IngestionDocumentImage img) =>
img.Content.HasValue && !string.IsNullOrEmpty(img.MediaType) && string.IsNullOrEmpty(img.AlternativeText);

private async Task ProcessAsync(List<IngestionDocumentImage> batch, CancellationToken cancellationToken)
{
if (image.Content.HasValue && !string.IsNullOrEmpty(image.MediaType)
&& string.IsNullOrEmpty(image.AlternativeText))
List<AIContent> contents = new(batch.Count);
foreach (var image in batch)
{
contents.Add(new DataContent(image.Content!.Value, image.MediaType!));
}

var response = await _options.ChatClient.GetResponseAsync<string[]>(
[_systemPrompt, new(ChatRole.User, contents)],
_options.ChatOptions,
cancellationToken: cancellationToken).ConfigureAwait(false);

if (response.Result.Length != contents.Count)
{
var response = await _chatClient.GetResponseAsync(
[
_systemPrompt,
new(ChatRole.User, [new DataContent(image.Content.Value, image.MediaType!)])
], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
throw new InvalidOperationException($"The AI chat service returned {response.Result.Length} instead of {contents.Count} results.");
}

image.AlternativeText = response.Text;
for (int i = 0; i < response.Result.Length; i++)
{
batch[i].AlternativeText = response.Result[i];
}
}
}
Loading
Loading