Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Extensions.Hosting;
/// <param name="hostBuilder">The <see cref="IHostApplicationBuilder"/> with which services are being registered.</param>
/// <param name="serviceKey">The service key used to register the <see cref="OllamaApiClient"/> service, if any.</param>
/// <param name="disableTracing">A flag to indicate whether tracing should be disabled.</param>
public class AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, string serviceKey, bool disableTracing)
public class AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, object serviceKey, bool disableTracing)
{
/// <summary>
/// The host application builder used to configure the application.
Expand All @@ -18,7 +18,7 @@ public class AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, s
/// <summary>
/// Gets the service key used to register the <see cref="OllamaApiClient"/> service, if any.
/// </summary>
public string ServiceKey { get; } = serviceKey;
public object ServiceKey { get; } = serviceKey;

/// <summary>
/// Gets a flag indicating whether tracing should be disabled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,25 @@ public static ChatClientBuilder AddKeyedChatClient(
this AspireOllamaApiClientBuilder builder)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentException.ThrowIfNullOrEmpty(builder.ServiceKey, nameof(builder.ServiceKey));

return builder.AddKeyedChatClient(builder.ServiceKey);
}

/// <summary>
/// Registers a keyed singleton <see cref="IChatClient"/> in the services provided by the <paramref name="builder"/> using the specified service key.
/// </summary>
/// <param name="builder">An <see cref="AspireOllamaApiClientBuilder" />.</param>
/// <param name="serviceKey">The service key to use for registering the <see cref="IChatClient"/>.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner <see cref="IChatClient"/>.</returns>
public static ChatClientBuilder AddKeyedChatClient(
this AspireOllamaApiClientBuilder builder,
object serviceKey)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));

return builder.HostBuilder.Services.AddKeyedChatClient(
builder.ServiceKey,
serviceKey,
services => CreateInnerChatClient(services, builder));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,24 @@ public static EmbeddingGeneratorBuilder<string, Embedding<float>> AddKeyedEmbedd
this AspireOllamaApiClientBuilder builder)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentException.ThrowIfNullOrEmpty(builder.ServiceKey, nameof(builder.ServiceKey));
return builder.AddKeyedEmbeddingGenerator(builder.ServiceKey);
}

/// <summary>
/// Registers a keyed singleton <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> in the services provided by the <paramref name="builder"/> using the specified service key.
/// </summary>
/// <param name="builder">An <see cref="AspireOllamaApiClientBuilder" />.</param>
/// <param name="serviceKey">The service key to use for registering the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</param>
/// <returns>A <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that can be used to build a pipeline around the inner <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</returns>
public static EmbeddingGeneratorBuilder<string, Embedding<float>> AddKeyedEmbeddingGenerator(
this AspireOllamaApiClientBuilder builder,
object serviceKey)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));

return builder.HostBuilder.Services.AddKeyedEmbeddingGenerator(
builder.ServiceKey,
serviceKey,
services => CreateInnerEmbeddingGenerator(services, builder));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using OllamaSharp;
using System.Data.Common;

Expand Down Expand Up @@ -43,6 +42,37 @@ public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApp
return AddOllamaClientInternal(builder, $"{DefaultConfigSectionName}:{connectionName}", connectionName, serviceKey: connectionName, configureSettings: configureSettings);
}

/// <summary>
/// Adds <see cref="IOllamaApiClient"/> services to the container using the specified <paramref name="serviceKey"/>.
/// </summary>
/// <param name="builder">The <see cref="IHostApplicationBuilder" /> to read config from and add services to.</param>
/// <param name="serviceKey">A unique key that identifies this instance of the Ollama client service.</param>
/// <param name="connectionName">A name used to retrieve the connection string from the ConnectionStrings configuration section.</param>
/// <param name="configureSettings">An optional delegate that can be used for customizing options. It's invoked after the settings are read from the configuration.</param>
/// <exception cref="UriFormatException">Thrown when no Ollama endpoint is provided.</exception>
public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, object serviceKey, string connectionName, Action<OllamaSharpSettings>? configureSettings = null)
{
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));
ArgumentException.ThrowIfNullOrWhiteSpace(connectionName, nameof(connectionName));
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
return AddOllamaClientInternal(builder, $"{DefaultConfigSectionName}:{connectionName}", connectionName, serviceKey: serviceKey, configureSettings: configureSettings);
}

/// <summary>
/// Adds <see cref="IOllamaApiClient"/> services to the container using the specified <paramref name="serviceKey"/>.
/// </summary>
/// <param name="builder">The <see cref="IHostApplicationBuilder" /> to read config from and add services to.</param>
/// <param name="serviceKey">A unique key that identifies this instance of the Ollama client service.</param>
/// <param name="settings">The settings required to configure the <see cref="IOllamaApiClient"/>.</param>
/// <exception cref="UriFormatException">Thrown when no Ollama endpoint is provided.</exception>
public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, object serviceKey, OllamaSharpSettings settings)
{
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(settings, nameof(settings));
return AddOllamaClientInternal(builder, DefaultConfigSectionName, serviceKey.ToString() ?? "default", serviceKey: serviceKey, configureSettings: null, settings: settings);
}

/// <summary>
/// Adds <see cref="IOllamaApiClient"/> and <see cref="IChatClient"/> services to the container.
/// </summary>
Expand Down Expand Up @@ -105,11 +135,15 @@ private static AspireOllamaApiClientBuilder AddOllamaClientInternal(
IHostApplicationBuilder builder,
string configurationSectionName,
string connectionName,
string? serviceKey = null,
Action<OllamaSharpSettings>? configureSettings = null)
object? serviceKey = null,
Action<OllamaSharpSettings>? configureSettings = null,
OllamaSharpSettings? settings = null)
{
OllamaSharpSettings settings = new();
builder.Configuration.GetSection(configurationSectionName).Bind(settings);
settings ??= new();
if (string.IsNullOrEmpty(settings.Endpoint?.ToString()))
{
builder.Configuration.GetSection(configurationSectionName).Bind(settings);
}

if (builder.Configuration.GetConnectionString(connectionName) is string connectionString)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,92 @@ public void CanSetMultipleKeyedClients()
Assert.NotEqual(client, client3);
}

[Fact]
public void CanSetMultipleKeyedClientsWithCustomServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}"),
new KeyValuePair<string, string?>("ConnectionStrings:Ollama2", "Endpoint=https://localhost:5002/"),
new KeyValuePair<string, string?>("ConnectionStrings:Ollama3", "Endpoint=https://localhost:5003/")
]);

// Use custom service keys instead of connection names
builder.AddKeyedOllamaApiClient("ChatModel", "Ollama");
builder.AddKeyedOllamaApiClient("VisionModel", "Ollama2");
builder.AddKeyedOllamaApiClient("EmbeddingModel", "Ollama3");

using var host = builder.Build();
var chatClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("ChatModel");
var visionClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("VisionModel");
var embeddingClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("EmbeddingModel");

Assert.Equal(Endpoint, chatClient.Uri);
Assert.Equal("https://localhost:5002/", visionClient.Uri?.ToString());
Assert.Equal("https://localhost:5003/", embeddingClient.Uri?.ToString());

Assert.NotEqual(chatClient, visionClient);
Assert.NotEqual(chatClient, embeddingClient);
Assert.NotEqual(visionClient, embeddingClient);
}

[Fact]
public void CanSetKeyedClientWithSettingsOverload()
{
var builder = Host.CreateEmptyApplicationBuilder(null);

var settings = new OllamaSharpSettings
{
Endpoint = Endpoint,
SelectedModel = "testmodel"
};

builder.AddKeyedOllamaApiClient("TestService", settings);

using var host = builder.Build();
var client = host.Services.GetRequiredKeyedService<IOllamaApiClient>("TestService");

Assert.Equal(Endpoint, client.Uri);
Assert.Equal("testmodel", client.SelectedModel);
}

[Fact]
public void CanUseSameConnectionWithDifferentServiceKeys()
{
// This test demonstrates the main use case from the issue:
// Using the same connection but different service keys for different models
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:LocalAI", $"Endpoint={Endpoint}")
]);

// Same connection, different service keys and models
builder.AddKeyedOllamaApiClient("ChatModel", "LocalAI", settings =>
{
settings.SelectedModel = "llama3.2";
});

builder.AddKeyedOllamaApiClient("VisionModel", "LocalAI", settings =>
{
settings.SelectedModel = "llava";
});

using var host = builder.Build();
var chatClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("ChatModel");
var visionClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("VisionModel");

// Both use the same endpoint
Assert.Equal(Endpoint, chatClient.Uri);
Assert.Equal(Endpoint, visionClient.Uri);

// But have different models
Assert.Equal("llama3.2", chatClient.SelectedModel);
Assert.Equal("llava", visionClient.SelectedModel);

// And are different instances
Assert.NotEqual(chatClient, visionClient);
}

[Fact]
public void RegisteringChatClientAndEmbeddingGeneratorReturnsCorrectModelForServices()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,83 @@ public void CanChainUseMethodsCorrectly()

using var host = builder.Build();
var client = host.Services.GetRequiredService<IChatClient>();

var distributedCacheClient = Assert.IsType<DistributedCachingChatClient>(client);
var functionInvocationClient = Assert.IsType<FunctionInvokingChatClient>(GetInnerClient(distributedCacheClient));
var otelClient = Assert.IsType<OpenTelemetryChatClient>(GetInnerClient(functionInvocationClient));

Assert.IsType<IOllamaApiClient>(GetInnerClient(otelClient), exactMatch: false);
}

[Fact]
public void CanSetMultipleKeyedChatClientsWithCustomServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}"),
new KeyValuePair<string, string?>("ConnectionStrings:Ollama2", "Endpoint=https://localhost:5002/")
]);

// Use custom service keys for different chat clients
builder.AddKeyedOllamaApiClient("ChatModel", "Ollama").AddKeyedChatClient();
builder.AddKeyedOllamaApiClient("VisionModel", "Ollama2").AddKeyedChatClient();

using var host = builder.Build();
var chatClient = host.Services.GetRequiredKeyedService<IChatClient>("ChatModel");
var visionClient = host.Services.GetRequiredKeyedService<IChatClient>("VisionModel");

Assert.Equal(Endpoint, chatClient.GetService<ChatClientMetadata>()?.ProviderUri);
Assert.Equal("https://localhost:5002/", visionClient.GetService<ChatClientMetadata>()?.ProviderUri?.ToString());

Assert.NotEqual(chatClient, visionClient);
}

[Fact]
public void CanSetMultipleChatClientsWithDifferentServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}")
]);

// Use one Ollama API client with multiple chat clients using different service keys
var cb = builder.AddKeyedOllamaApiClient("OllamaKey", "Ollama");
cb.AddKeyedChatClient("ChatKey1");
cb.AddKeyedChatClient("ChatKey2");

using var host = builder.Build();
var chatClient1 = host.Services.GetRequiredKeyedService<IChatClient>("ChatKey1");
var chatClient2 = host.Services.GetRequiredKeyedService<IChatClient>("ChatKey2");

Assert.Equal(Endpoint, chatClient1.GetService<ChatClientMetadata>()?.ProviderUri);
Assert.Equal(Endpoint, chatClient2.GetService<ChatClientMetadata>()?.ProviderUri);

Assert.NotEqual(chatClient1, chatClient2);
}

[Fact]
public void CanMixChatClientsAndEmbeddingGeneratorsWithCustomServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}")
]);

// Use one Ollama API client with both chat clients and embedding generators using different service keys
var cb = builder.AddKeyedOllamaApiClient("OllamaKey", "Ollama");
cb.AddKeyedChatClient("ChatKey1");
cb.AddKeyedEmbeddingGenerator("EmbeddingKey1");

using var host = builder.Build();
var chatClient1 = host.Services.GetRequiredKeyedService<IChatClient>("ChatKey1");
var embeddingGenerator = host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("EmbeddingKey1");

Assert.Equal(Endpoint, chatClient1.GetService<ChatClientMetadata>()?.ProviderUri);
Assert.Equal(Endpoint, embeddingGenerator.GetService<EmbeddingGeneratorMetadata>()?.ProviderUri);

Assert.Equal(chatClient1 as IOllamaApiClient, embeddingGenerator as IOllamaApiClient);
}

[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "get_InnerClient")]
private static extern IChatClient GetInnerClient(DelegatingChatClient client);
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,32 @@ public void CanChainUseMethodsCorrectly()
Assert.IsType<IOllamaApiClient>(GetInnerGenerator(otelClient), exactMatch: false);
}

[Fact]
public void CanSetMultipleEmbeddingGeneratorsWithDifferentServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}")
]);

// Use one Ollama API client with multiple embedding generators using different service keys
var cb = builder.AddKeyedOllamaApiClient("OllamaKey", "Ollama");
cb.AddKeyedEmbeddingGenerator("EmbedKey1");
cb.AddKeyedEmbeddingGenerator("EmbedKey2");

using var host = builder.Build();
var embedGenerator1 = host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("EmbedKey1");
var embedGenerator2 = host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("EmbedKey2");

Assert.Equal(Endpoint, embedGenerator1.GetService<EmbeddingGeneratorMetadata>()?.ProviderUri);
Assert.Equal(Endpoint, embedGenerator2.GetService<EmbeddingGeneratorMetadata>()?.ProviderUri);

Assert.NotEqual(embedGenerator1, embedGenerator2);
}

private static IEmbeddingGenerator<TInput, TEmbedding> GetInnerGenerator<TInput, TEmbedding>(DelegatingEmbeddingGenerator<TInput, TEmbedding> generator)
where TEmbedding : Embedding =>
(IEmbeddingGenerator<TInput,TEmbedding>)(generator.GetType()
(IEmbeddingGenerator<TInput, TEmbedding>)(generator.GetType()
.GetProperty("InnerGenerator", BindingFlags.Instance | BindingFlags.NonPublic)?
.GetValue(generator, null) ?? throw new InvalidOperationException());
}
Loading