Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feedback
  • Loading branch information
christothes committed May 1, 2025
commit dcc655cef51a42547f8ea791fb0dc580cec7d0d6
61 changes: 12 additions & 49 deletions src/Utility/ChatTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace OpenAI.Chat;
/// <summary>
/// Provides functionality to manage and execute OpenAI function tools for chat completions.
/// </summary>
public class ChatTools : ToolsBase<ChatTool>
public class ChatTools : Tools<ChatTool>
{
/// <summary>
/// Initializes a new instance of the ChatTools class with an optional embedding client.
Expand All @@ -23,14 +23,11 @@ public ChatTools(EmbeddingClient client = null) : base(client) { }
/// <summary>
/// Initializes a new instance of the ChatTools class with the specified tool types.
/// </summary>
/// <param name="tool">The primary tool type to add.</param>
/// <param name="additionalTools">Additional tool types to add.</param>
public ChatTools(Type tool, params Type[] additionalTools) : this((EmbeddingClient)null)
public ChatTools(params Type[] additionalTools) : this((EmbeddingClient)null)
{
Add(tool);
if (additionalTools != null)
foreach (var t in additionalTools)
Add(t);
foreach (var t in additionalTools)
AddLocalTool(t);
}

internal override ChatTool MethodInfoToTool(MethodInfo methodInfo) =>
Expand All @@ -56,7 +53,7 @@ internal override async Task Add(BinaryData toolDefinitions, McpClient client)
#pragma warning restore IL2026, IL3050

var chatTool = ChatTool.CreateFunctionTool(name, description, BinaryData.FromString(inputSchema));
_definitions.Add(chatTool);
_tools.Add(chatTool);
toolsToVectorize.Add(chatTool);
_mcpMethods[name] = client.CallToolAsync;
}
Expand Down Expand Up @@ -99,10 +96,10 @@ protected override ChatTool ParseToolDefinition(BinaryData data)
/// Converts the tools collection to chat completion options.
/// </summary>
/// <returns>A new ChatCompletionOptions containing all defined tools.</returns>
public ChatCompletionOptions ToOptions()
public ChatCompletionOptions CreateCompletionOptions()
{
var options = new ChatCompletionOptions();
foreach (var tool in _definitions)
foreach (var tool in _tools)
options.Tools.Add(tool);
return options;
}
Expand All @@ -113,13 +110,13 @@ public ChatCompletionOptions ToOptions()
/// <param name="prompt">The prompt to find relevant tools for.</param>
/// <param name="options">Options for filtering tools, including maximum number of tools to return.</param>
/// <returns>A new <see cref="ChatCompletionOptions"/> containing the most relevant tools.</returns>
public ChatCompletionOptions ToOptions(string prompt, ToolFindOptions options = null)
public ChatCompletionOptions CreateCompletionOptions(string prompt, ToolSelectionOptions options = null)
{
if (!CanFilterTools)
return ToOptions();
return CreateCompletionOptions();

var completionOptions = new ChatCompletionOptions();
foreach (var tool in RelatedTo(prompt, options?.MaxEntries ?? 5))
foreach (var tool in RelatedTo(prompt, options?.MaxTools ?? 5))
completionOptions.Tools.Add(tool);
return completionOptions;
}
Expand All @@ -128,7 +125,7 @@ public ChatCompletionOptions ToOptions(string prompt, ToolFindOptions options =
/// Implicitly converts ChatTools to <see cref="ChatCompletionOptions"/>.
/// </summary>
/// <param name="tools">The ChatTools instance to convert.</param>
public static implicit operator ChatCompletionOptions(ChatTools tools) => tools.ToOptions();
public static implicit operator ChatCompletionOptions(ChatTools tools) => tools.CreateCompletionOptions();

internal string CallLocal(ChatToolCall call)
{
Expand Down Expand Up @@ -171,7 +168,7 @@ internal async Task<string> CallMcpAsync(ChatToolCall call)
/// </summary>
/// <param name="toolCalls">The collection of tool calls to execute.</param>
/// <returns>A collection of tool chat messages containing the results.</returns>
public async Task<IEnumerable<ToolChatMessage>> CallAllAsync(IEnumerable<ChatToolCall> toolCalls)
public async Task<IEnumerable<ToolChatMessage>> CallAsync(IEnumerable<ChatToolCall> toolCalls)
{
var messages = new List<ToolChatMessage>();
foreach (ChatToolCall toolCall in toolCalls)
Expand All @@ -195,39 +192,5 @@ public async Task<IEnumerable<ToolChatMessage>> CallAllAsync(IEnumerable<ChatToo

return messages;
}

/// <summary>
/// Executes all tool calls and returns both results and any failed tool names.
/// </summary>
/// <param name="toolCalls">The collection of tool calls to execute.</param>
/// <returns>A result object containing successful tool messages and failed tool names.</returns>
public async Task<ToolCallChatResult> CallAllWithErrorsAsync(IEnumerable<ChatToolCall> toolCalls)
{
List<string> failed = null;
var messages = new List<ToolChatMessage>();

foreach (ChatToolCall toolCall in toolCalls)
{
bool isMcpTool = false;
if (!_methods.ContainsKey(toolCall.FunctionName))
{
if (_mcpMethods.ContainsKey(toolCall.FunctionName))
{
isMcpTool = true;
}
else
{
failed ??= new();
failed.Add(toolCall.FunctionName);
continue;
}
}

var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : CallLocal(toolCall);
messages.Add(new ToolChatMessage(toolCall.Id, result));
}

return new(messages, failed);
}
}

20 changes: 10 additions & 10 deletions src/Utility/ResponseTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace OpenAI.Responses;
/// <summary>
/// Provides functionality to manage and execute OpenAI function tools for responses.
/// </summary>
public class ResponseTools : ToolsBase<ResponseTool>
public class ResponseTools : Tools<ResponseTool>
{
/// <summary>
/// Initializes a new instance of the ResponseTools class with an optional embedding client.
Expand All @@ -27,10 +27,10 @@ public ResponseTools(EmbeddingClient client = null) : base(client) { }
/// <param name="additionalTools">Additional tool types to add.</param>
public ResponseTools(Type tool, params Type[] additionalTools) : this((EmbeddingClient)null)
{
Add(tool);
AddLocalTool(tool);
if (additionalTools != null)
foreach (var t in additionalTools)
Add(t);
AddLocalTool(t);
}

internal override ResponseTool MethodInfoToTool(MethodInfo methodInfo) =>
Expand All @@ -56,7 +56,7 @@ internal override async Task Add(BinaryData toolDefinitions, McpClient client)
#pragma warning restore IL2026, IL3050

var responseTool = ResponseTool.CreateFunctionTool(name, description, BinaryData.FromString(inputSchema));
_definitions.Add(responseTool);
_tools.Add(responseTool);
toolsToVectorize.Add(responseTool);
_mcpMethods[name] = client.CallToolAsync;
}
Expand Down Expand Up @@ -100,10 +100,10 @@ protected override ResponseTool ParseToolDefinition(BinaryData data)
/// Converts the tools collection to <see cref="ResponseCreationOptions"> configured with the tools contained in this instance..
/// </summary>
/// <returns>A new ResponseCreationOptions containing all defined tools.</returns>
public ResponseCreationOptions ToOptions()
public ResponseCreationOptions CreateResponseOptions()
{
var options = new ResponseCreationOptions();
foreach (var tool in _definitions)
foreach (var tool in _tools)
options.Tools.Add(tool);
return options;
}
Expand All @@ -114,13 +114,13 @@ public ResponseCreationOptions ToOptions()
/// <param name="prompt">The prompt to find relevant tools for.</param>
/// <param name="options">Options for filtering tools, including maximum number of tools to return.</param>
/// <returns>A new ResponseCreationOptions containing the most relevant tools.</returns>
public ResponseCreationOptions ToOptions(string prompt, ToolFindOptions options = null)
public ResponseCreationOptions CreateResponseOptions(string prompt, ToolSelectionOptions options = null)
{
if (!CanFilterTools)
return ToOptions();
return CreateResponseOptions();

var completionOptions = new ResponseCreationOptions();
foreach (var tool in RelatedTo(prompt, options?.MaxEntries ?? 5))
foreach (var tool in RelatedTo(prompt, options?.MaxTools ?? 5))
completionOptions.Tools.Add(tool);
return completionOptions;
}
Expand All @@ -129,7 +129,7 @@ public ResponseCreationOptions ToOptions(string prompt, ToolFindOptions options
/// Implicitly converts ResponseTools to ResponseCreationOptions.
/// </summary>
/// <param name="tools">The ResponseTools instance to convert.</param>
public static implicit operator ResponseCreationOptions(ResponseTools tools) => tools.ToOptions();
public static implicit operator ResponseCreationOptions(ResponseTools tools) => tools.CreateResponseOptions();

internal string CallLocal(FunctionCallResponseItem call)
{
Expand Down
31 changes: 0 additions & 31 deletions src/Utility/ToolCallChatResult.cs

This file was deleted.

46 changes: 23 additions & 23 deletions src/Utility/ToolsBase.cs → src/Utility/Tools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,29 @@ namespace OpenAI;
/// Base class containing common functionality for tool management.
/// </summary>
/// <typeparam name="TTool">The concrete tool type (<see cref="ChatTool"/> or <see cref="ResponseTool"/>)</typeparam>
public abstract class ToolsBase<TTool> where TTool : class
public abstract class Tools<TTool> where TTool : class
{
protected static readonly BinaryData s_noparams = BinaryData.FromString("""{ "type" : "object", "properties" : {} }""");
internal static readonly BinaryData s_noparams = BinaryData.FromString("""{ "type" : "object", "properties" : {} }""");

protected readonly Dictionary<string, MethodInfo> _methods = [];
protected readonly Dictionary<string, Func<string, BinaryData, Task<BinaryData>>> _mcpMethods = [];
protected readonly List<TTool> _definitions = [];
protected readonly EmbeddingClient _client;
protected readonly List<VectorbaseEntry> _entries = [];
internal readonly Dictionary<string, MethodInfo> _methods = [];
internal readonly Dictionary<string, Func<string, BinaryData, Task<BinaryData>>> _mcpMethods = [];
internal readonly List<TTool> _tools = [];
internal readonly EmbeddingClient _client;
internal readonly List<VectorbaseEntry> _entries = [];

internal readonly List<McpClient> _mcpClients = [];
internal readonly Dictionary<string, McpClient> _mcpClientsByEndpoint = [];
protected const string _mcpToolSeparator = "_-_";
internal const string _mcpToolSeparator = "_-_";

protected ToolsBase(EmbeddingClient client = null)
protected Tools(EmbeddingClient client = null)
{
_client = client;
}

/// <summary>
/// Gets the list of defined tools.
/// </summary>
public IList<TTool> Definitions => _definitions;
public IList<TTool> ToolList => _tools;

/// <summary>
/// Gets whether tools can be filtered using embeddings provided by the provided <see cref="EmbeddingClient"/> .
Expand All @@ -50,7 +50,7 @@ protected ToolsBase(EmbeddingClient client = null)
public void AddLocalTools(params Type[] tools)
{
foreach (Type functionHolder in tools)
Add(functionHolder);
AddLocalTool(functionHolder);
}

/// <summary>
Expand Down Expand Up @@ -83,21 +83,21 @@ public async Task AddMcpServerAsync(Uri serverEndpoint)
/// Adds all public static methods from the specified type as tools.
/// </summary>
/// <param name="functions">The type containing tool methods.</param>
public void Add(Type functions)
public void AddLocalTool(Type functions)
{
#pragma warning disable IL2070
foreach (MethodInfo function in functions.GetMethods(BindingFlags.Public | BindingFlags.Static))
{
Add(function);
AddLocalTool(function);
}
#pragma warning restore IL2070
}

public void Add(MethodInfo function)
public void AddLocalTool(MethodInfo function)
{
string name = function.Name;
var tool = MethodInfoToTool(function);
_definitions.Add(tool);
_tools.Add(tool);
_methods[name] = function;
}

Expand Down Expand Up @@ -210,22 +210,22 @@ private async Task<ReadOnlyMemory<float>> GetEmbedding(string text)
protected IEnumerable<TTool> RelatedTo(string prompt, int maxEntries = 5)
{
if (!CanFilterTools)
return _definitions;
return _tools;

var options = new ToolFindOptions { MaxEntries = maxEntries };
var options = new ToolSelectionOptions { MaxTools = maxEntries };
return Find(prompt, options).Select(e => ParseToolDefinition(e.Data));
}

protected IEnumerable<VectorbaseEntry> Find(string prompt, ToolFindOptions options)
protected IEnumerable<VectorbaseEntry> Find(string prompt, ToolSelectionOptions options)
{
ReadOnlyMemory<float> vector = GetEmbedding(prompt).GetAwaiter().GetResult();
lock (_entries)
{
var distances = _entries
.Select((e, i) => (Distance: 1f - CosineSimilarity(e.Vector.Span, vector.Span), Index: i))
.OrderBy(t => t.Distance)
.Take(options.MaxEntries)
.Where(t => t.Distance <= options.Threshold);
.Take(options.MaxTools)
.Where(t => t.Distance <= options.MinVectorDistance);

return distances.Select(d => _entries[d.Index]);
}
Expand Down Expand Up @@ -256,16 +256,16 @@ private static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float>
/// <summary>
/// Options for finding related tools.
/// </summary>
public class ToolFindOptions
public class ToolSelectionOptions
{
/// <summary>
/// Gets or sets the maximum number of tools to return. Default is 3.
/// </summary>
public int MaxEntries { get; set; } = 3;
public int MaxTools { get; set; } = 3;

/// <summary>
/// Gets or sets the similarity threshold for including tools. Default is 0.29.
/// </summary>
public float Threshold { get; set; } = 0.29f;
public float MinVectorDistance { get; set; } = 0.29f;
}
}