diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs
index 02875e9de98..27b84273b5b 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs
@@ -1,11 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using Microsoft.Shared.Diagnostics;
+
namespace Microsoft.Extensions.AI;
/// Represents the options for an embedding generation request.
public class EmbeddingGenerationOptions
{
+ private int? _dimensions;
+
+ /// Gets or sets the number of dimensions requested in the embedding.
+ public int? Dimensions
+ {
+ get => _dimensions;
+ set
+ {
+ if (value is not null)
+ {
+ _ = Throw.IfLessThan(value.Value, 1);
+ }
+
+ _dimensions = value;
+ }
+ }
+
/// Gets or sets the model ID for the embedding generation request.
public string? ModelId { get; set; }
@@ -22,6 +41,7 @@ public virtual EmbeddingGenerationOptions Clone() =>
new()
{
ModelId = ModelId,
+ Dimensions = Dimensions,
AdditionalProperties = AdditionalProperties?.Clone(),
};
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs
index 27bf001b3ff..155e047279f 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs
@@ -135,17 +135,11 @@ void IDisposable.Dispose()
{
OpenAI.Embeddings.EmbeddingGenerationOptions openAIOptions = new()
{
- Dimensions = _dimensions,
+ Dimensions = options?.Dimensions ?? _dimensions,
};
if (options?.AdditionalProperties is { Count: > 0 } additionalProperties)
{
- // Allow per-instance dimensions to be overridden by a per-call property
- if (additionalProperties.TryGetValue(nameof(openAIOptions.Dimensions), out int? dimensions))
- {
- openAIOptions.Dimensions = dimensions;
- }
-
if (additionalProperties.TryGetValue(nameof(openAIOptions.EndUserId), out string? endUserId))
{
openAIOptions.EndUserId = endUserId;
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs
index e9dd45959c7..fbc8b390abf 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Text.Json;
using Xunit;
@@ -14,10 +15,20 @@ public void Constructor_Parameterless_PropsDefaulted()
EmbeddingGenerationOptions options = new();
Assert.Null(options.ModelId);
Assert.Null(options.AdditionalProperties);
+ Assert.Null(options.Dimensions);
EmbeddingGenerationOptions clone = options.Clone();
Assert.Null(clone.ModelId);
Assert.Null(clone.AdditionalProperties);
+ Assert.Null(clone.Dimensions);
+ }
+
+ [Fact]
+ public void InvalidArgs_Throws()
+ {
+ EmbeddingGenerationOptions options = new();
+ Assert.Throws(() => options.Dimensions = 0);
+ Assert.Throws(() => options.Dimensions = -1);
}
[Fact]
@@ -31,13 +42,16 @@ public void Properties_Roundtrip()
};
options.ModelId = "modelId";
+ options.Dimensions = 1536;
options.AdditionalProperties = additionalProps;
Assert.Equal("modelId", options.ModelId);
+ Assert.Equal(1536, options.Dimensions);
Assert.Same(additionalProps, options.AdditionalProperties);
EmbeddingGenerationOptions clone = options.Clone();
Assert.Equal("modelId", clone.ModelId);
+ Assert.Equal(1536, clone.Dimensions);
Assert.Equal(additionalProps, clone.AdditionalProperties);
}
@@ -53,6 +67,7 @@ public void JsonSerialization_Roundtrips()
options.ModelId = "model";
options.AdditionalProperties = additionalProps;
+ options.Dimensions = 1536;
string json = JsonSerializer.Serialize(options, TestJsonSerializerContext.Default.EmbeddingGenerationOptions);
@@ -60,6 +75,7 @@ public void JsonSerialization_Roundtrips()
Assert.NotNull(deserialized);
Assert.Equal("model", deserialized.ModelId);
+ Assert.Equal(1536, deserialized.Dimensions);
Assert.NotNull(deserialized.AdditionalProperties);
Assert.Single(deserialized.AdditionalProperties);