Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
tests
  • Loading branch information
christothes committed May 2, 2025
commit 567ffa0316dc4e40fc50ae9f5e11b7291c9f755c
2 changes: 1 addition & 1 deletion src/Utility/ChatTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void AddLocalTool(MethodInfo function)
/// </summary>
/// <param name="client">The MCP client instance.</param>
/// <returns>A task representing the asynchronous operation.</returns>
internal async Task AddMcpToolsAsync(McpClient client)
public async Task AddMcpToolsAsync(McpClient client)
{
if (client == null) throw new ArgumentNullException(nameof(client));
_mcpClientsByEndpoint[client.ServerEndpoint.AbsoluteUri] = client;
Expand Down
29 changes: 28 additions & 1 deletion src/Utility/MCP/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,45 @@

namespace OpenAI;

internal class McpClient
/// <summary>
/// Client for interacting with a Model Context Protocol (MCP) server.
/// </summary>
public class McpClient
{
private readonly McpSession _session;
private readonly ClientPipeline _pipeline;

/// <summary>
/// Gets the endpoint URI of the MCP server.
/// </summary>
public virtual Uri ServerEndpoint { get; }

/// <summary>
/// Initializes a new instance of the <see cref="McpClient"/> class.
/// </summary>
/// <param name="endpoint">The URI endpoint of the MCP server.</param>
/// <param name="pipeline">Optional custom client pipeline. If not provided, a default pipeline will be created.</param>
public McpClient(Uri endpoint, ClientPipeline pipeline = null)
{
_pipeline = pipeline ?? ClientPipeline.Create();
_session = new McpSession(endpoint, _pipeline);
ServerEndpoint = endpoint;
}

/// <summary>
/// Starts the MCP client session by initializing the connection to the server.
/// </summary>
/// <returns>A task that represents the asynchronous operation.</returns>
public virtual async Task StartAsync()
{
await _session.EnsureInitializedAsync().ConfigureAwait(false);
}

/// <summary>
/// Lists all available tools from the MCP server.
/// </summary>
/// <returns>A task that represents the asynchronous operation. The task result contains the binary data representing the tools list.</returns>
/// <exception cref="InvalidOperationException">Thrown when the session is not initialized.</exception>
public virtual async Task<BinaryData> ListToolsAsync()
{
if (_session == null)
Expand All @@ -31,6 +51,13 @@ public virtual async Task<BinaryData> ListToolsAsync()
return await _session.SendMethod("tools/list").ConfigureAwait(false);
}

/// <summary>
/// Calls a specific tool on the MCP server.
/// </summary>
/// <param name="toolName">The name of the tool to call.</param>
/// <param name="parameters">The parameters to pass to the tool as binary data.</param>
/// <returns>A task that represents the asynchronous operation. The task result contains the binary data representing the tool's response.</returns>
/// <exception cref="InvalidOperationException">Thrown when the session is not initialized.</exception>
public virtual async Task<BinaryData> CallToolAsync(string toolName, BinaryData parameters)
{
if (_session == null)
Expand Down
4 changes: 2 additions & 2 deletions src/Utility/ResponseTools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public void AddLocalTool(MethodInfo function)
/// </summary>
/// <param name="client">The MCP client instance.</param>
/// <returns>A task representing the asynchronous operation.</returns>
internal async Task AddMcpToolsAsync(McpClient client)
public async Task AddMcpToolsAsync(McpClient client)
{
if (client == null) throw new ArgumentNullException(nameof(client));
_mcpClientsByEndpoint[client.ServerEndpoint.AbsoluteUri] = client;
Expand Down Expand Up @@ -257,7 +257,7 @@ internal string CallLocal(FunctionCallResponseItem call)
arguments.Add(argument.Value.ValueKind switch
{
JsonValueKind.String => argument.Value.GetString()!,
JsonValueKind.Number => argument.Value.GetDouble(),
JsonValueKind.Number => argument.Value.GetInt32(),
JsonValueKind.True => true,
JsonValueKind.False => false,
_ => throw new NotImplementedException()
Expand Down
197 changes: 197 additions & 0 deletions tests/Utility/ChatToolsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,201 @@ public void ThrowsWhenCallingNonExistentTool()

Assert.ThrowsAsync<InvalidOperationException>(() => tools.CallAsync(toolCalls));
}

[Test]
public async Task AddMcpToolsAsync_AddsToolsCorrectly()
{
// Arrange
var mcpEndpoint = new Uri("http://localhost:1234");
var mockMcpClient = new Mock<McpClient>(mcpEndpoint, null);
var tools = new ChatTools();

var mockToolsResponse = BinaryData.FromString(@"
{
""tools"": [
{
""name"": ""mcp-tool-1"",
""description"": ""This is the first MCP tool."",
""inputSchema"": {
""type"": ""object"",
""properties"": {
""param1"": {
""type"": ""string"",
""description"": ""The first param.""
},
""param2"": {
""type"": ""string"",
""description"": ""The second param.""
}
},
""required"": [""param1""]
}
},
{
""name"": ""mcp-tool-2"",
""description"": ""This is the second MCP tool."",
""inputSchema"": {
""type"": ""object"",
""properties"": {
""param1"": {
""type"": ""string"",
""description"": ""The first param.""
},
""param2"": {
""type"": ""string"",
""description"": ""The second param.""
}
},
""required"": []
}
}
]
}");

mockMcpClient.Setup(c => c.StartAsync())
.Returns(Task.CompletedTask);
mockMcpClient.Setup(c => c.ListToolsAsync())
.ReturnsAsync(mockToolsResponse);
mockMcpClient.Setup(c => c.CallToolAsync(It.IsAny<string>(), It.IsAny<BinaryData>()))
.ReturnsAsync(BinaryData.FromString("\"test result\""));
mockMcpClient.SetupGet(c => c.ServerEndpoint)
.Returns(mcpEndpoint);

// Act
await tools.AddMcpToolsAsync(mockMcpClient.Object);

// Assert
Assert.That(tools.Tools, Has.Count.EqualTo(2));
var toolNames = tools.Tools.Select(t => t.FunctionName).ToList();
Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-1"));
Assert.That(toolNames, Contains.Item("localhost1234_-_mcp-tool-2"));

// Verify we can call the tools
var toolCall = ChatToolCall.CreateFunctionToolCall("call1", "localhost1234_-_mcp-tool-1", BinaryData.FromString(@"{""param1"": ""test""}"));
var result = await tools.CallAsync(new[] { toolCall });
var resultsList = result.ToList();

Assert.That(resultsList, Has.Count.EqualTo(1));
Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("\"test result\""));
}

[Test]
public async Task CreateCompletionOptions_WithMaxToolsParameter_FiltersTools()
{
// Arrange
var mcpEndpoint = new Uri("http://localhost:1234");
var mockMcpClient = new Mock<McpClient>(mcpEndpoint, null);
var tools = new ChatTools(mockEmbeddingClient.Object);

var mockToolsResponse = BinaryData.FromString(@"
{
""tools"": [
{
""name"": ""math-tool"",
""description"": ""Tool for performing mathematical calculations"",
""inputSchema"": {
""type"": ""object"",
""properties"": {
""expression"": {
""type"": ""string"",
""description"": ""The mathematical expression to evaluate""
}
}
}
},
{
""name"": ""weather-tool"",
""description"": ""Tool for getting weather information"",
""inputSchema"": {
""type"": ""object"",
""properties"": {
""location"": {
""type"": ""string"",
""description"": ""The location to get weather for""
}
}
}
},
{
""name"": ""translate-tool"",
""description"": ""Tool for translating text between languages"",
""inputSchema"": {
""type"": ""object"",
""properties"": {
""text"": {
""type"": ""string"",
""description"": ""Text to translate""
},
""targetLanguage"": {
""type"": ""string"",
""description"": ""Target language code""
}
}
}
}
]
}");

// Setup mock responses
var embeddings = new[]
{
OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.8f, 0.5f }),
OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.6f, 0.4f }),
OpenAIEmbeddingsModelFactory.OpenAIEmbedding(vector: new[] { 0.3f, 0.2f })
};
var embeddingCollection = OpenAIEmbeddingsModelFactory.OpenAIEmbeddingCollection(
items: embeddings,
model: "text-embedding-ada-002",
usage: OpenAIEmbeddingsModelFactory.EmbeddingTokenUsage(30, 30));
var mockResponse = new MockPipelineResponse(200);

mockMcpClient.Setup(c => c.StartAsync())
.Returns(Task.CompletedTask);
mockMcpClient.Setup(c => c.ListToolsAsync())
.ReturnsAsync(mockToolsResponse);
mockMcpClient.Setup(c => c.CallToolAsync("math-tool", It.IsAny<BinaryData>()))
.ReturnsAsync(BinaryData.FromString("\"math-tool result\""));
mockMcpClient.SetupGet(c => c.ServerEndpoint)
.Returns(mcpEndpoint);

mockEmbeddingClient
.Setup(c => c.GenerateEmbeddingAsync(
It.IsAny<string>(),
It.IsAny<EmbeddingGenerationOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(ClientResult.FromValue(embeddings[0], mockResponse));

mockEmbeddingClient
.Setup(c => c.GenerateEmbeddingsAsync(
It.IsAny<IList<string>>(),
It.IsAny<EmbeddingGenerationOptions>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(ClientResult.FromValue(embeddingCollection, mockResponse));

// Add the tools
await tools.AddMcpToolsAsync(mockMcpClient.Object);

// Act & Assert
// Test with maxTools = 1
var options1 = await Task.Run(() => tools.CreateCompletionOptions("calculate 2+2", 1, 0.5f));
Assert.That(options1.Tools, Has.Count.EqualTo(1));

// Test with maxTools = 2
var options2 = await Task.Run(() => tools.CreateCompletionOptions("calculate 2+2", 2, 0.5f));
Assert.That(options2.Tools, Has.Count.EqualTo(2));

// Test that we can call the tools after filtering
var toolCall = ChatToolCall.CreateFunctionToolCall(
"call1",
"localhost1234_-_math-tool",
BinaryData.FromString(@"{""expression"": ""2+2""}"));
var result = await tools.CallAsync(new[] { toolCall });
Assert.That(result.First().ToolCallId, Is.EqualTo("call1"));
Assert.That(result.First().Content[0].Text, Is.EqualTo("\"math-tool result\""));

// Verify expected interactions
mockMcpClient.Verify(c => c.StartAsync(), Times.Once);
mockMcpClient.Verify(c => c.ListToolsAsync(), Times.Once);
}
}
Loading