Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
async tools
  • Loading branch information
christothes committed May 6, 2025
commit c984bc8fab8d26c0aeb374b9c4e1179cb6f34761
18 changes: 5 additions & 13 deletions src/Utility/ChatTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace OpenAI.Chat;
/// <summary>
/// Provides functionality to manage and execute OpenAI function tools for chat completions.
/// </summary>
//[Experimental("OPENAIMCP001")]
public class ChatTools
{
private readonly Dictionary<string, MethodInfo> _methods = [];
Expand Down Expand Up @@ -223,26 +224,17 @@ await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
}
}

internal string CallLocal(ChatToolCall call)
internal async Task<string> CallFunctionToolAsync(ChatToolCall call)
{
var arguments = new List<object>();
if (call.FunctionArguments != null)
{
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
return $"I don't have a tool called {call.FunctionName}";
throw new InvalidOperationException($"Tool not found: {call.FunctionName}");

ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
}
return CallLocal(call.FunctionName, [.. arguments]);
}

private string CallLocal(string name, object[] arguments)
{
if (!_methods.TryGetValue(name, out MethodInfo method))
return $"I don't have a tool called {name}";

object result = method.Invoke(null, arguments);
return result?.ToString() ?? string.Empty;
return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]);
}

internal async Task<string> CallMcpAsync(ChatToolCall call)
Expand Down Expand Up @@ -285,7 +277,7 @@ public async Task<IEnumerable<ToolChatMessage>> CallAsync(IEnumerable<ChatToolCa
}
}

var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : CallLocal(toolCall);
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall).ConfigureAwait(false);
messages.Add(new ToolChatMessage(toolCall.Id, result));
}

Expand Down
1 change: 1 addition & 0 deletions src/Utility/MCP/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace OpenAI.Agents;
/// <summary>
/// Client for interacting with a Model Context Protocol (MCP) server.
/// </summary>
//[Experimental("OPENAIMCP001")]
public class McpClient
{
private readonly McpSession _session;
Expand Down
21 changes: 7 additions & 14 deletions src/Utility/ResponseTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace OpenAI.Responses;
/// <summary>
/// Provides functionality to manage and execute OpenAI function tools for responses.
/// </summary>
public class ResponseTools
//[Experimental("OPENAIMCP001")
public class CallLocalAsync
{
private readonly Dictionary<string, MethodInfo> _methods = [];
private readonly Dictionary<string, Func<string, BinaryData, Task<BinaryData>>> _mcpMethods = [];
Expand All @@ -29,7 +30,7 @@ public class ResponseTools
/// Initializes a new instance of the ResponseTools class with an optional embedding client.
/// </summary>
/// <param name="client">The embedding client used for tool vectorization, or null to disable vectorization.</param>
public ResponseTools(EmbeddingClient client = null)
public CallLocalAsync(EmbeddingClient client = null)
{
_client = client;
}
Expand All @@ -38,7 +39,7 @@ public ResponseTools(EmbeddingClient client = null)
/// Initializes a new instance of the ResponseTools class with the specified tool types.
/// </summary>
/// <param name="tools">Additional tool types to add.</param>
public ResponseTools(params Type[] tools) : this((EmbeddingClient)null)
public CallLocalAsync(params Type[] tools) : this((EmbeddingClient)null)
{
foreach (var t in tools)
AddFunctionTool(t);
Expand Down Expand Up @@ -225,7 +226,7 @@ await ToolsUtility.GetEmbeddingAsync(_client, prompt).ConfigureAwait(false) :
}
}

internal string CallLocal(FunctionCallResponseItem call)
internal async Task<string> CallFunctionToolAsync(FunctionCallResponseItem call)
{
List<object> arguments = new();
if (call.FunctionArguments != null)
Expand All @@ -235,16 +236,8 @@ internal string CallLocal(FunctionCallResponseItem call)

ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
}
return CallLocal(call.FunctionName, [.. arguments]);
}

private string CallLocal(string name, object[] arguments)
{
if (!_methods.TryGetValue(name, out MethodInfo method))
return $"I don't have a tool called {name}";

object result = method.Invoke(null, arguments);
return result?.ToString() ?? string.Empty;
return await ToolsUtility.CallFunctionToolAsync(_methods, call.FunctionName, [.. arguments]);
}

internal async Task<string> CallMcpAsync(FunctionCallResponseItem call)
Expand Down Expand Up @@ -282,7 +275,7 @@ public async Task<FunctionCallOutputResponseItem> CallAsync(FunctionCallResponse
}
}

var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : CallLocal(toolCall);
var result = isMcpTool ? await CallMcpAsync(toolCall).ConfigureAwait(false) : await CallFunctionToolAsync(toolCall);
return new FunctionCallOutputResponseItem(toolCall.CallId, result);
}
}
Expand Down
40 changes: 40 additions & 0 deletions src/Utility/ToolsUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,44 @@ internal static BinaryData SerializeTool(string name, string description, Binary
stream.Position = 0;
return BinaryData.FromStream(stream);
}

internal static async Task<string> CallFunctionToolAsync(Dictionary<string, MethodInfo> methods, string name, object[] arguments)
{
if (!methods.TryGetValue(name, out MethodInfo method))
throw new InvalidOperationException($"Tool not found: {name}");

object result;
if (IsGenericTask(method.ReturnType, out Type taskResultType))
{
// Method is async, invoke and await
var task = (Task)method.Invoke(null, arguments);
await task.ConfigureAwait(false);
// Get the Result property from the Task
result = taskResultType.GetProperty("Result").GetValue(task);
}
else
{
// Method is synchronous
result = method.Invoke(null, arguments);
}

return result?.ToString() ?? string.Empty;
}

private static bool IsGenericTask(Type type, out Type taskResultType)
{
while (type != null && type != typeof(object))
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>))
{
taskResultType = type;//type.GetGenericArguments()[0];
return true;
}

type = type.BaseType!;
}

taskResultType = null;
return false;
}
}
88 changes: 88 additions & 0 deletions tests/Utility/ChatToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,45 @@ private class TestTools
public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
}

private class TestToolsAsync
{
public static async Task<string> EchoAsync(string message)
{
await Task.Delay(1); // Simulate async work
return message;
}

public static async Task<int> AddAsync(int a, int b)
{
await Task.Delay(1); // Simulate async work
return a + b;
}

public static async Task<double> MultiplyAsync(double x, double y)
{
await Task.Delay(1); // Simulate async work
return x * y;
}

public static async Task<bool> IsGreaterThanAsync(long value1, long value2)
{
await Task.Delay(1); // Simulate async work
return value1 > value2;
}

public static async Task<float> DivideAsync(float numerator, float denominator)
{
await Task.Delay(1); // Simulate async work
return numerator / denominator;
}

public static async Task<string> ConcatWithBoolAsync(string text, bool flag)
{
await Task.Delay(1); // Simulate async work
return $"{text}:{flag}";
}
}

private Mock<EmbeddingClient> mockEmbeddingClient;

[SetUp]
Expand All @@ -52,6 +91,21 @@ public void CanAddLocalTools()
Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
}

[Test]
public void CanAddAsyncLocalTools()
{
var tools = new ChatTools();
tools.AddFunctionTools(typeof(TestToolsAsync));

Assert.That(tools.Tools, Has.Count.EqualTo(6));
Assert.That(tools.Tools.Any(t => t.FunctionName == "EchoAsync"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "AddAsync"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "MultiplyAsync"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThanAsync"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "DivideAsync"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBoolAsync"));
}

[Test]
public async Task CanCallToolsAsync()
{
Expand Down Expand Up @@ -86,6 +140,40 @@ public async Task CanCallToolsAsync()
Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
}

[Test]
public async Task CanCallAsyncToolsAsync()
{
var tools = new ChatTools();
tools.AddFunctionTools(typeof(TestToolsAsync));

var toolCalls = new[]
{
ChatToolCall.CreateFunctionToolCall("call1", "EchoAsync", BinaryData.FromString(@"{""message"": ""Hello""}")),
ChatToolCall.CreateFunctionToolCall("call2", "AddAsync", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
ChatToolCall.CreateFunctionToolCall("call3", "MultiplyAsync", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThanAsync", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
ChatToolCall.CreateFunctionToolCall("call5", "DivideAsync", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBoolAsync", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
};

var results = await tools.CallAsync(toolCalls);
var resultsList = results.ToList();

Assert.That(resultsList, Has.Count.EqualTo(6));
Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello"));
Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5"));
Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3"));
Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5"));
Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4"));
Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True"));
Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5"));
Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5"));
Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6"));
Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
}

[Test]
public void CreatesCompletionOptionsWithTools()
{
Expand Down
Loading