Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fb
  • Loading branch information
christothes committed May 5, 2025
commit 76b907af66de9552b9247144deb0d2ebfe4585a0
5 changes: 4 additions & 1 deletion src/Utility/ChatTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ internal string CallLocal(ChatToolCall call)
var arguments = new List<object>();
if (call.FunctionArguments != null)
{
ToolsUtility.ParseFunctionCallArgs(call.FunctionArguments, out arguments);
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
return $"I don't have a tool called {call.FunctionName}";

ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
}
return CallLocal(call.FunctionName, [.. arguments]);
}
Expand Down
5 changes: 4 additions & 1 deletion src/Utility/ResponseTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ internal string CallLocal(FunctionCallResponseItem call)
List<object> arguments = new();
if (call.FunctionArguments != null)
{
ToolsUtility.ParseFunctionCallArgs(call.FunctionArguments, out arguments);
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
return $"I don't have a tool called {call.FunctionName}";

ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
}
return CallLocal(call.FunctionName, [.. arguments]);
}
Expand Down
31 changes: 25 additions & 6 deletions src/Utility/ToolsUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ internal static ReadOnlySpan<byte> ClrToJsonTypeUtf8(Type clrType) =>
{
Type t when t == typeof(double) => "number"u8,
Type t when t == typeof(int) => "number"u8,
Type t when t == typeof(long) => "number"u8,
Type t when t == typeof(float) => "number"u8,
Type t when t == typeof(string) => "string"u8,
Type t when t == typeof(bool) => "bool"u8,
_ => throw new NotImplementedException()
Expand Down Expand Up @@ -146,19 +148,36 @@ internal static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float
return result;
}

internal static void ParseFunctionCallArgs(BinaryData functionCallArguments, out List<object> arguments)
internal static void ParseFunctionCallArgs(MethodInfo method, BinaryData functionCallArguments, out List<object> arguments)
{
arguments = new List<object>();
using var document = JsonDocument.Parse(functionCallArguments);
foreach (JsonProperty argument in document.RootElement.EnumerateObject())
var parameters = method.GetParameters();
var argumentsByName = document.RootElement.EnumerateObject().ToDictionary(p => p.Name, p => p.Value);

foreach (var param in parameters)
{
arguments.Add(argument.Value.ValueKind switch
if (!argumentsByName.TryGetValue(param.Name!, out var value))
{
if (param.HasDefaultValue)
{
arguments.Add(param.DefaultValue!);
continue;
}
throw new JsonException($"Required parameter '{param.Name}' not found in function call arguments.");
}

arguments.Add(value.ValueKind switch
{
JsonValueKind.String => argument.Value.GetString()!,
JsonValueKind.Number => argument.Value.GetInt32(),
JsonValueKind.String => value.GetString()!,
JsonValueKind.Number when param.ParameterType == typeof(int) => value.GetInt32(),
JsonValueKind.Number when param.ParameterType == typeof(long) => value.GetInt64(),
JsonValueKind.Number when param.ParameterType == typeof(double) => value.GetDouble(),
JsonValueKind.Number when param.ParameterType == typeof(float) => value.GetSingle(),
JsonValueKind.True => true,
JsonValueKind.False => false,
_ => throw new NotImplementedException()
JsonValueKind.Null when param.HasDefaultValue => param.DefaultValue!,
_ => throw new NotImplementedException($"Conversion from {value.ValueKind} to {param.ParameterType.Name} is not implemented.")
});
}
}
Expand Down
32 changes: 28 additions & 4 deletions tests/Utility/ChatToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ private class TestTools
{
public static string Echo(string message) => message;
public static int Add(int a, int b) => a + b;
public static double Multiply(double x, double y) => x * y;
public static bool IsGreaterThan(long value1, long value2) => value1 > value2;
public static float Divide(float numerator, float denominator) => numerator / denominator;
public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
}

private Mock<EmbeddingClient> mockEmbeddingClient;
Expand All @@ -39,9 +43,13 @@ public void CanAddLocalTools()
var tools = new ChatTools();
tools.AddFunctionTools(typeof(TestTools));

Assert.That(tools.Tools, Has.Count.EqualTo(2));
Assert.That(tools.Tools, Has.Count.EqualTo(6));
Assert.That(tools.Tools.Any(t => t.FunctionName == "Echo"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "Add"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "Multiply"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThan"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "Divide"));
Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
}

[Test]
Expand All @@ -53,17 +61,29 @@ public async Task CanCallToolsAsync()
var toolCalls = new[]
{
ChatToolCall.CreateFunctionToolCall("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")),
ChatToolCall.CreateFunctionToolCall("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}"))
ChatToolCall.CreateFunctionToolCall("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
ChatToolCall.CreateFunctionToolCall("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
ChatToolCall.CreateFunctionToolCall("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
};

var results = await tools.CallAsync(toolCalls);
var resultsList = results.ToList();

Assert.That(resultsList, Has.Count.EqualTo(2));
Assert.That(resultsList, Has.Count.EqualTo(6));
Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello"));
Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5"));
Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3"));
Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5"));
Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4"));
Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True"));
Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5"));
Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5"));
Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6"));
Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
}

[Test]
Expand All @@ -74,9 +94,13 @@ public void CreatesCompletionOptionsWithTools()

var options = tools.ToChatCompletionOptions();

Assert.That(options.Tools, Has.Count.EqualTo(2));
Assert.That(options.Tools, Has.Count.EqualTo(6));
Assert.That(options.Tools.Any(t => t.FunctionName == "Echo"));
Assert.That(options.Tools.Any(t => t.FunctionName == "Add"));
Assert.That(options.Tools.Any(t => t.FunctionName == "Multiply"));
Assert.That(options.Tools.Any(t => t.FunctionName == "IsGreaterThan"));
Assert.That(options.Tools.Any(t => t.FunctionName == "Divide"));
Assert.That(options.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
}

[Test]
Expand Down
61 changes: 49 additions & 12 deletions tests/Utility/ResponseToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ private class TestTools
{
public static string Echo(string message) => message;
public static int Add(int a, int b) => a + b;
public static double Multiply(double x, double y) => x * y;
public static bool IsGreaterThan(long value1, long value2) => value1 > value2;
public static float Divide(float numerator, float denominator) => numerator / denominator;
public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
}

private Mock<EmbeddingClient> mockEmbeddingClient;
Expand All @@ -37,9 +41,13 @@ public void CanAddLocalTools()
var tools = new ResponseTools();
tools.AddFunctionTools(typeof(TestTools));

Assert.That(tools.Tools, Has.Count.EqualTo(2));
Assert.That(tools.Tools, Has.Count.EqualTo(6));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool")));
}

[Test]
Expand All @@ -48,17 +56,42 @@ public async Task CanCallToolAsync()
var tools = new ResponseTools();
tools.AddFunctionTools(typeof(TestTools));

var toolCall = new FunctionCallResponseItem("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}"));
var result = await tools.CallAsync(toolCall);

Assert.That(result.CallId, Is.EqualTo("call1"));
Assert.That(result.FunctionOutput, Is.EqualTo("Hello"));

var addCall = new FunctionCallResponseItem("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}"));
result = await tools.CallAsync(addCall);
var toolCalls = new[]
{
new FunctionCallResponseItem("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")),
new FunctionCallResponseItem("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
new FunctionCallResponseItem("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
new FunctionCallResponseItem("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
new FunctionCallResponseItem("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
new FunctionCallResponseItem("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
};

Assert.That(result.CallId, Is.EqualTo("call2"));
Assert.That(result.FunctionOutput, Is.EqualTo("5"));
foreach (var toolCall in toolCalls)
{
var result = await tools.CallAsync(toolCall);
Assert.That(result.CallId, Is.EqualTo(toolCall.CallId));
switch (toolCall.CallId)
{
case "call1":
Assert.That(result.FunctionOutput, Is.EqualTo("Hello"));
break;
case "call2":
Assert.That(result.FunctionOutput, Is.EqualTo("5"));
break;
case "call3":
Assert.That(result.FunctionOutput, Is.EqualTo("7.5"));
break;
case "call4":
Assert.That(result.FunctionOutput, Is.EqualTo("True"));
break;
case "call5":
Assert.That(result.FunctionOutput, Is.EqualTo("5"));
break;
case "call6":
Assert.That(result.FunctionOutput, Is.EqualTo("Test:True"));
break;
}
}
}

[Test]
Expand All @@ -69,9 +102,13 @@ public void CreatesResponseOptionsWithTools()

var options = tools.ToResponseCreationOptions();

Assert.That(options.Tools, Has.Count.EqualTo(2));
Assert.That(options.Tools, Has.Count.EqualTo(6));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide")));
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool")));
}

[Test]
Expand Down