Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fb
  • Loading branch information
christothes committed May 2, 2025
commit 385dd353f3acee34173b05e425455a02c8f86463
32 changes: 27 additions & 5 deletions src/Utility/ChatTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,44 @@ public ChatCompletionOptions CreateCompletionOptions(string prompt, int maxTools
return CreateCompletionOptions();

var completionOptions = new ChatCompletionOptions();
foreach (var tool in FindRelatedTools(prompt, maxTools, minVectorDistance))
foreach (var tool in FindRelatedTools(false, prompt, maxTools, minVectorDistance).GetAwaiter().GetResult())
completionOptions.Tools.Add(tool);
return completionOptions;
}

private IEnumerable<ChatTool> FindRelatedTools(string prompt, int maxTools, float minVectorDistance)
/// <summary>
/// Converts the tools collection to <see cref="ChatCompletionOptions"/>, filtered by relevance to the given prompt.
/// </summary>
/// <param name="prompt">The prompt to find relevant tools for.</param>
/// <param name="maxTools">The maximum number of tools to return. Default is 3.</param>
/// <param name="minVectorDistance">The similarity threshold for including tools. Default is 0.29.</param>
/// <returns>A new <see cref="ChatCompletionOptions"/> containing the most relevant tools.</returns>
public async Task<ChatCompletionOptions> CreateCompletionOptionsAsync(string prompt, int maxTools = 3, float minVectorDistance = 0.29f)
{
if (!CanFilterTools)
return CreateCompletionOptions();

var completionOptions = new ChatCompletionOptions();
foreach (var tool in await FindRelatedTools(true, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
completionOptions.Tools.Add(tool);
return completionOptions;
}

private async Task<IEnumerable<ChatTool>> FindRelatedTools(bool async, string prompt, int maxTools, float minVectorDistance)
{
if (!CanFilterTools)
return _tools;

return FindVectorMatches(prompt, maxTools, minVectorDistance).Select(e => ParseToolDefinition(e.Data));
return (await FindVectorMatches(async, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
.Select(e => ParseToolDefinition(e.Data));
}

private IEnumerable<VectorbaseEntry> FindVectorMatches(string prompt, int maxTools, float minVectorDistance)
private async Task<IEnumerable<VectorbaseEntry>> FindVectorMatches(bool async, string prompt, int maxTools, float minVectorDistance)
{
var vector = ToolsUtility.GetEmbedding(_client, prompt).GetAwaiter().GetResult();
var vector = async ?
await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
ToolsUtility.GetEmbedding(_client, prompt);

lock (_entries)
{
var distances = _entries
Expand Down
2 changes: 1 addition & 1 deletion src/Utility/MCP/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ public void Dispose()
_cancellationSource.Dispose();
}

public struct SseEvent
internal struct SseEvent
{
public string Event { get; set; }
public string Data { get; set; }
Expand Down
31 changes: 26 additions & 5 deletions src/Utility/ResponseTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,22 +207,43 @@ public ResponseCreationOptions CreateResponseOptions(string prompt, int maxTools
return CreateResponseOptions();

var completionOptions = new ResponseCreationOptions();
foreach (var tool in FindRelatedTools(prompt, maxTools, minVectorDistance))
foreach (var tool in FindRelatedTools(false, prompt, maxTools, minVectorDistance).GetAwaiter().GetResult())
completionOptions.Tools.Add(tool);
return completionOptions;
}

private IEnumerable<ResponseTool> FindRelatedTools(string prompt, int maxTools, float minVectorDistance)
/// <summary>
/// Converts the tools collection to <see cref="ResponseCreationOptions">, filtered by relevance to the given prompt.
/// </summary>
/// <param name="prompt">The prompt to find relevant tools for.</param>
/// <param name="maxTools">The maximum number of tools to return. Default is 5.</param>
/// <param name="minVectorDistance">The similarity threshold for including tools. Default is 0.29.</param>
/// <returns>A new ResponseCreationOptions containing the most relevant tools.</returns>
public async Task<ResponseCreationOptions> CreateResponseOptionsAsync(string prompt, int maxTools = 5, float minVectorDistance = 0.29f)
{
if (!CanFilterTools)
return CreateResponseOptions();

var completionOptions = new ResponseCreationOptions();
foreach (var tool in await FindRelatedTools(true, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
completionOptions.Tools.Add(tool);
return completionOptions;
}

private async Task<IEnumerable<ResponseTool>> FindRelatedTools(bool async, string prompt, int maxTools, float minVectorDistance)
{
if (!CanFilterTools)
return _tools;

return FindVectorMatches(prompt, maxTools, minVectorDistance).Select(e => ParseToolDefinition(e.Data));
return (await FindVectorMatches(async, prompt, maxTools, minVectorDistance).ConfigureAwait(false))
.Select(e => ParseToolDefinition(e.Data));
}

private IEnumerable<VectorbaseEntry> FindVectorMatches(string prompt, int maxTools, float minVectorDistance)
private async Task<IEnumerable<VectorbaseEntry>> FindVectorMatches(bool async, string prompt, int maxTools, float minVectorDistance)
{
var vector = ToolsUtility.GetEmbedding(_client, prompt).GetAwaiter().GetResult();
var vector = async ?
await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
ToolsUtility.GetEmbedding(_client, prompt);
lock (_entries)
{
var distances = _entries
Expand Down
8 changes: 7 additions & 1 deletion src/Utility/ToolsUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,18 @@ internal static BinaryData BuildParametersJson(ParameterInfo[] parameters)
return BinaryData.FromStream(stream);
}

internal static async Task<ReadOnlyMemory<float>> GetEmbedding(EmbeddingClient client, string text)
internal static async Task<ReadOnlyMemory<float>> GetEmbeddingAsync(EmbeddingClient client, string text)
{
var result = await client.GenerateEmbeddingAsync(text).ConfigureAwait(false);
return result.Value.ToFloats();
}

internal static ReadOnlyMemory<float> GetEmbedding(EmbeddingClient client, string text)
{
var result = client.GenerateEmbedding(text);
return result.Value.ToFloats();
}

internal static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
{
float dot = 0, xSumSquared = 0, ySumSquared = 0;
Expand Down
6 changes: 3 additions & 3 deletions tests/Utility/ChatToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public async Task CanFilterToolsByRelevance()
var tools = new ChatTools(mockEmbeddingClient.Object);
tools.AddLocalTools(typeof(TestTools));

var options = await Task.Run(() => tools.CreateCompletionOptions("Need to add two numbers", 1, 0.5f));
var options = await tools.CreateCompletionOptionsAsync("Need to add two numbers", 1, 0.5f);

Assert.That(options.Tools, Has.Count.LessThanOrEqualTo(1));
}
Expand Down Expand Up @@ -306,11 +306,11 @@ public async Task CreateCompletionOptions_WithMaxToolsParameter_FiltersTools()

// Act & Assert
// Test with maxTools = 1
var options1 = await Task.Run(() => tools.CreateCompletionOptions("calculate 2+2", 1, 0.5f));
var options1 = await tools.CreateCompletionOptionsAsync("calculate 2+2", 1, 0.5f);
Assert.That(options1.Tools, Has.Count.EqualTo(1));

// Test with maxTools = 2
var options2 = await Task.Run(() => tools.CreateCompletionOptions("calculate 2+2", 2, 0.5f));
var options2 = await tools.CreateCompletionOptionsAsync("calculate 2+2", 2, 0.5f);
Assert.That(options2.Tools, Has.Count.EqualTo(2));

// Test that we can call the tools after filtering
Expand Down
15 changes: 6 additions & 9 deletions tests/Utility/ResponseToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public async Task CanFilterToolsByRelevance()
var tools = new ResponseTools(mockEmbeddingClient.Object);
tools.AddLocalTools(typeof(TestTools));

var options = await Task.Run(() => tools.CreateResponseOptions("Need to add two numbers", 1, 0.5f));
var options = await tools.CreateResponseOptionsAsync("Need to add two numbers", 1, 0.5f);

Assert.That(options.Tools, Has.Count.LessThanOrEqualTo(1));
}
Expand Down Expand Up @@ -310,20 +310,17 @@ public async Task CreateResponseOptions_WithMaxToolsParameter_FiltersTools()

// Act & Assert
// Test with maxTools = 1
var options1 = await Task.Run(() => tools.CreateResponseOptions("calculate 2+2", 1, 0.5f));
var options1 = await tools.CreateResponseOptionsAsync("calculate 2+2", 1, 0.5f);
Assert.That(options1.Tools, Has.Count.EqualTo(1));

// Test with maxTools = 2
var options2 = await Task.Run(() => tools.CreateResponseOptions("calculate 2+2", 2, 0.5f));
var options2 = await tools.CreateResponseOptionsAsync("calculate 2+2", 2, 0.5f);
Assert.That(options2.Tools, Has.Count.EqualTo(2));

// Test that tool choice affects results
var optionsWithToolChoice = await Task.Run(() =>
{
var opts = tools.CreateResponseOptions("calculate 2+2", 1, 0.5f);
opts.ToolChoice = ResponseToolChoice.CreateRequiredChoice();
return opts;
});
var optionsWithToolChoice = await tools.CreateResponseOptionsAsync("calculate 2+2", 1, 0.5f);
optionsWithToolChoice.ToolChoice = ResponseToolChoice.CreateRequiredChoice();

Assert.That(optionsWithToolChoice.ToolChoice, Is.Not.Null);
Assert.That(optionsWithToolChoice.Tools, Has.Count.EqualTo(1));

Expand Down