diff --git a/client/client.go b/client/client.go index cda7665ef..0d47fcbf3 100644 --- a/client/client.go +++ b/client/client.go @@ -77,9 +77,16 @@ func (c *Client) Start(ctx context.Context) error { if c.transport == nil { return fmt.Errorf("transport is nil") } - err := c.transport.Start(ctx) - if err != nil { - return err + + if _, ok := c.transport.(*transport.Stdio); !ok { + // the stdio transport from NewStdioMCPClientWithOptions + // is already started, dont start again. + // + // Start the transport for other transport types + err := c.transport.Start(ctx) + if err != nil { + return err + } } c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { diff --git a/client/stdio.go b/client/stdio.go index 199ec14c3..6e6e1af99 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -12,7 +12,7 @@ import ( // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. // -// NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// NOTICE: NewStdioMCPClient will start the connection automatically. // This is for backward compatibility. func NewStdioMCPClient( command string, @@ -28,7 +28,6 @@ func NewStdioMCPClient( // such as setting a custom command function. // // NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. -// Don't call the Start method manually. // This is for backward compatibility. func NewStdioMCPClientWithOptions( command string, diff --git a/client/transport/sse.go b/client/transport/sse.go index 70a391905..305c93167 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -36,12 +36,12 @@ type SSE struct { headerFunc HTTPHeaderFunc logger util.Logger - started atomic.Bool - closed atomic.Bool - cancelSSEStream context.CancelFunc - protocolVersion atomic.Value // string - onConnectionLost func(error) - connectionLostMu sync.RWMutex + started atomic.Bool + closed atomic.Bool + cancelSSEStream context.CancelFunc + protocolVersion atomic.Value // string + onConnectionLost func(error) + connectionLostMu sync.RWMutex // OAuth support oauthHandler *OAuthHandler @@ -220,7 +220,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) { c.connectionLostMu.RLock() handler := c.onConnectionLost c.connectionLostMu.RUnlock() - + if handler != nil { // This is not actually an error - HTTP2 idle timeout disconnection handler(err) diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 488164c79..f3a95f4b0 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -27,7 +27,7 @@ type Stdio struct { cmd *exec.Cmd cmdFunc CommandFunc stdin io.WriteCloser - stdout *bufio.Reader + stdout *bufio.Scanner stderr io.ReadCloser responses map[string]chan *JSONRPCResponse mu sync.RWMutex @@ -72,7 +72,7 @@ func WithCommandLogger(logger util.Logger) StdioOption { func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio { return &Stdio{ stdin: output, - stdout: bufio.NewReader(input), + stdout: bufio.NewScanner(input), stderr: logging, responses: make(map[string]chan *JSONRPCResponse), @@ -180,7 +180,7 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { c.cmd = cmd c.stdin = stdin c.stderr = stderr - c.stdout = bufio.NewReader(stdout) + c.stdout = bufio.NewScanner(stdout) if err := cmd.Start(); err != nil { return fmt.Errorf("failed to start command: %w", err) @@ -247,14 +247,15 @@ func (c *Stdio) readResponses() { case <-c.done: return default: - line, err := c.stdout.ReadString('\n') - if err != nil { - if err != io.EOF && !errors.Is(err, context.Canceled) { + if !c.stdout.Scan() { + err := c.stdout.Err() + if err != nil && !errors.Is(err, context.Canceled) { c.logger.Errorf("Error reading from stdout: %v", err) } return } + line := c.stdout.Text() // First try to parse as a generic message to check for ID field var baseMessage struct { JSONRPC string `json:"jsonrpc"` diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 268aeb342..40a5bc695 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -605,7 +605,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) err := c.createGETConnectionToServer(connectCtx) cancel() - + if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") @@ -621,7 +621,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { if err != nil { c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) } - + // Use context-aware sleep select { case <-time.After(retryInterval): @@ -704,15 +704,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Create a new context with timeout for request handling, respecting parent context requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - + response, err := handler(requestCtx, request) if err != nil { c.logger.Errorf("error handling request %s: %v", request.Method, err) - + // Determine appropriate JSON-RPC error code based on error type var errorCode int var errorMessage string - + // Check for specific sampling-related errors if errors.Is(err, context.Canceled) { errorCode = -32800 // Request cancelled @@ -731,7 +731,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON errorMessage = err.Error() } } - + // Send error response errorResponse := &JSONRPCResponse{ JSONRPC: "2.0", diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index edba61eac..4a38f280e 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -16,27 +16,27 @@ import ( // TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport func TestStreamableHTTP_SamplingFlow(t *testing.T) { - // Create simple test server + // Create simple test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Just respond OK to any requests w.WriteHeader(http.StatusOK) })) defer server.Close() - + // Create HTTP client transport client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Set up sampling request handler var handledRequest *JSONRPCRequest handlerCalled := make(chan struct{}) client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { handledRequest = &request close(handlerCalled) - + // Simulate sampling handler response result := map[string]any{ "role": "assistant", @@ -47,25 +47,25 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { "model": "test-model", "stopReason": "stop_sequence", } - + resultBytes, _ := json.Marshal(result) - + return &JSONRPCResponse{ JSONRPC: "2.0", ID: request.ID, Result: resultBytes, }, nil }) - + // Start the client ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - + err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Test direct request handling (simulating a sampling request) samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -83,10 +83,10 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { }, }, } - + // Directly test request handling client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for handler to be called select { case <-handlerCalled: @@ -94,12 +94,12 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Handler was not called within timeout") } - + // Verify the request was handled if handledRequest == nil { t.Fatal("Sampling request was not handled") } - + if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) { t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method) } @@ -109,7 +109,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { var errorHandled sync.WaitGroup errorHandled.Add(1) - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]any @@ -118,7 +118,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { w.WriteHeader(http.StatusOK) return } - + // Check if this is an error response if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]any) @@ -132,25 +132,25 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer server.Close() - + client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Set up request handler that returns an error client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { return nil, fmt.Errorf("sampling failed") }) - + // Start the client ctx := context.Background() err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Simulate incoming sampling request samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -158,10 +158,10 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { Method: string(mcp.MethodSamplingCreateMessage), Params: map[string]any{}, } - + // This should trigger error handling client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for error to be handled errorHandled.Wait() } @@ -170,7 +170,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { var errorReceived bool errorReceivedChan := make(chan struct{}) - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]any @@ -179,12 +179,12 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { w.WriteHeader(http.StatusOK) return } - + // Check if this is an error response with method not found if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]any) if code, ok := errorMap["code"].(float64); ok && code == -32601 { - if message, ok := errorMap["message"].(string); ok && + if message, ok := errorMap["message"].(string); ok && strings.Contains(message, "no handler configured") { errorReceived = true close(errorReceivedChan) @@ -195,21 +195,21 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer server.Close() - + client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Don't set any request handler - + ctx := context.Background() err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Simulate incoming sampling request samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -217,10 +217,10 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { Method: string(mcp.MethodSamplingCreateMessage), Params: map[string]any{}, } - + // This should trigger "method not found" error client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for error to be received select { case <-errorReceivedChan: @@ -228,7 +228,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Method not found error was not received within timeout") } - + if !errorReceived { t.Error("Expected method not found error, but didn't receive it") } @@ -241,13 +241,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Verify it implements BidirectionalInterface _, ok := any(client).(BidirectionalInterface) if !ok { t.Error("StreamableHTTP should implement BidirectionalInterface") } - + // Test SetRequestHandler handlerSet := false handlerSetChan := make(chan struct{}) @@ -256,7 +256,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { close(handlerSetChan) return nil, nil }) - + // Verify handler was set by triggering it ctx := context.Background() client.handleIncomingRequest(ctx, JSONRPCRequest{ @@ -264,7 +264,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { ID: mcp.NewRequestId(1), Method: "test", }) - + // Wait for handler to be called select { case <-handlerSetChan: @@ -272,7 +272,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Handler was not called within timeout") } - + if !handlerSet { t.Error("Request handler was not properly set or called") } @@ -315,16 +315,16 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Track which requests have been received and their completion order var requestOrder []int var orderMutex sync.Mutex - + // Set up request handler that simulates different processing times client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { // Extract request ID to determine processing time requestIDValue := request.ID.Value() - + var delay time.Duration var responseText string var requestNum int - + // First request (ID 1) takes longer, second request (ID 2) completes faster if requestIDValue == int64(1) { delay = 100 * time.Millisecond @@ -341,7 +341,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Simulate processing time time.Sleep(delay) - + // Record completion order orderMutex.Lock() requestOrder = append(requestOrder, requestNum) @@ -428,7 +428,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Verify completion order: request 2 should complete first orderMutex.Lock() defer orderMutex.Unlock() - + if len(requestOrder) != 2 { t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder)) } @@ -493,4 +493,4 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { } } } -} \ No newline at end of file +} diff --git a/examples/oauth_client/main.go b/examples/oauth_client/main.go index 639d8cb7a..27d3b6180 100644 --- a/examples/oauth_client/main.go +++ b/examples/oauth_client/main.go @@ -129,9 +129,11 @@ func maybeAuthorize(err error) { log.Fatalf("Failed to generate state: %v", err) } - err = oauthHandler.RegisterClient(context.Background(), "mcp-go-oauth-example") - if err != nil { - log.Fatalf("Failed to register client: %v", err) + if oauthHandler.GetClientID() == "" { + err = oauthHandler.RegisterClient(context.Background(), "mcp-go-oauth-example") + if err != nil { + log.Fatalf("Failed to register client: %v", err) + } } // Get the authorization URL diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go index 093b59817..a655fde62 100644 --- a/examples/sampling_client/main.go +++ b/examples/sampling_client/main.go @@ -95,7 +95,7 @@ func main() { // Setup graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - + // Create a context that cancels on signal ctx, cancel := context.WithCancel(ctx) go func() { @@ -103,7 +103,7 @@ func main() { log.Println("Received shutdown signal, closing client...") cancel() }() - + // Move defer after error checking defer func() { if err := mcpClient.Close(); err != nil { diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go index 98817e6f8..946223f7b 100644 --- a/examples/sampling_http_client/main.go +++ b/examples/sampling_http_client/main.go @@ -63,7 +63,7 @@ func main() { log.Fatalf("Failed to create HTTP transport: %v", err) } defer httpTransport.Close() - + // Create client with sampling support mcpClient := client.NewClient( httpTransport, @@ -81,7 +81,7 @@ func main() { initRequest := mcp.InitializeRequest{ Params: mcp.InitializeParams{ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - Capabilities: mcp.ClientCapabilities{ + Capabilities: mcp.ClientCapabilities{ // Sampling capability will be automatically added by the client }, ClientInfo: mcp.Implementation{ @@ -90,7 +90,7 @@ func main() { }, }, } - + _, err = mcpClient.Initialize(ctx, initRequest) if err != nil { log.Fatalf("Failed to initialize MCP session: %v", err) @@ -102,7 +102,7 @@ func main() { // In a real application, you would keep the client running to handle sampling requests // For this example, we'll just demonstrate that it's working - + // Keep the client running (in a real app, you'd have your main application logic here) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) @@ -113,4 +113,4 @@ func main() { case <-sigChan: log.Println("Received shutdown signal") } -} \ No newline at end of file +} diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go index 95a2bf29b..a178accee 100644 --- a/examples/sampling_http_server/main.go +++ b/examples/sampling_http_server/main.go @@ -147,4 +147,4 @@ func main() { if err := httpServer.Start(":8080"); err != nil { log.Fatalf("Server failed to start: %v", err) } -} \ No newline at end of file +} diff --git a/examples/simple_client/main.go b/examples/simple_client/main.go index c0f48593a..d7497b6ca 100644 --- a/examples/simple_client/main.go +++ b/examples/simple_client/main.go @@ -121,18 +121,24 @@ func main() { serverInfo.ServerInfo.Version) fmt.Printf("Server capabilities: %+v\n", serverInfo.Capabilities) + // Perform health check using ping + fmt.Println("Performing health check...") + if err := c.Ping(ctx); err != nil { + log.Fatalf("Health check failed: %v", err) + } + fmt.Println("Server is alive and responding") + // List available tools if the server supports them if serverInfo.Capabilities.Tools != nil { fmt.Println("Fetching available tools...") toolsRequest := mcp.ListToolsRequest{} toolsResult, err := c.ListTools(ctx, toolsRequest) if err != nil { - log.Printf("Failed to list tools: %v", err) - } else { - fmt.Printf("Server has %d tools available\n", len(toolsResult.Tools)) - for i, tool := range toolsResult.Tools { - fmt.Printf(" %d. %s - %s\n", i+1, tool.Name, tool.Description) - } + log.Fatalf("Failed to list tools: %v", err) + } + fmt.Printf("Server has %d tools available\n", len(toolsResult.Tools)) + for i, tool := range toolsResult.Tools { + fmt.Printf(" %d. %s - %s\n", i+1, tool.Name, tool.Description) } } @@ -142,12 +148,11 @@ func main() { resourcesRequest := mcp.ListResourcesRequest{} resourcesResult, err := c.ListResources(ctx, resourcesRequest) if err != nil { - log.Printf("Failed to list resources: %v", err) - } else { - fmt.Printf("Server has %d resources available\n", len(resourcesResult.Resources)) - for i, resource := range resourcesResult.Resources { - fmt.Printf(" %d. %s - %s\n", i+1, resource.URI, resource.Name) - } + log.Fatalf("Failed to list resources: %v", err) + } + fmt.Printf("Server has %d resources available\n", len(resourcesResult.Resources)) + for i, resource := range resourcesResult.Resources { + fmt.Printf(" %d. %s - %s\n", i+1, resource.URI, resource.Name) } } diff --git a/mcp/tools.go b/mcp/tools.go index 3f3674923..493e8c778 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -565,6 +565,8 @@ type Tool struct { InputSchema ToolInputSchema `json:"inputSchema"` // Alternative to InputSchema - allows arbitrary JSON Schema to be provided RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // A JSON Schema object defining the expected output returned by the tool . + OutputSchema ToolOutputSchema `json:"outputSchema,omitempty"` // Optional JSON Schema defining expected output structure RawOutputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling // Optional properties describing tool behavior @@ -601,7 +603,12 @@ func (t Tool) MarshalJSON() ([]byte, error) { // Add output schema if present if t.RawOutputSchema != nil { + if t.OutputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both OutputSchema and RawOutputSchema set: %w", t.Name, errToolSchemaConflict) + } m["outputSchema"] = t.RawOutputSchema + } else if t.OutputSchema.Type != "" { // If no output schema is specified, do not return anything + m["outputSchema"] = t.OutputSchema } m["annotations"] = t.Annotations @@ -609,15 +616,19 @@ func (t Tool) MarshalJSON() ([]byte, error) { return json.Marshal(m) } -type ToolInputSchema struct { +// ToolArgumentsSchema represents a JSON Schema for tool arguments. +type ToolArgumentsSchema struct { Defs map[string]any `json:"$defs,omitempty"` Type string `json:"type"` Properties map[string]any `json:"properties,omitempty"` Required []string `json:"required,omitempty"` } +type ToolInputSchema ToolArgumentsSchema // For retro-compatibility +type ToolOutputSchema ToolArgumentsSchema + // MarshalJSON implements the json.Marshaler interface for ToolInputSchema. -func (tis ToolInputSchema) MarshalJSON() ([]byte, error) { +func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { m := make(map[string]any) m["type"] = tis.Type @@ -780,7 +791,15 @@ func WithOutputSchema[T any]() ToolOption { return } - t.RawOutputSchema = json.RawMessage(mcpSchema) + // Retrieve the schema from raw JSON + if err := json.Unmarshal(mcpSchema, &t.OutputSchema); err != nil { + // Skip and maintain backward compatibility + return + } + + // Always set the type to "object" as of the current MCP spec + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + t.OutputSchema.Type = "object" } } diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 13c0f5643..4bb07aa5d 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -586,27 +586,51 @@ func TestToolWithOutputSchema(t *testing.T) { Email string `json:"email,omitempty" jsonschema_description:"Email address"` } - tool := NewTool("test_tool", - WithDescription("Test tool with output schema"), - WithOutputSchema[TestOutput](), - WithString("input", Required()), - ) - - // Check that RawOutputSchema was set - assert.NotNil(t, tool.RawOutputSchema) + tests := []struct { + name string + tool Tool + expectedOutputSchema bool + }{ + { + name: "default behavior", + tool: NewTool("test_tool", + WithDescription("Test tool with output schema"), + WithOutputSchema[TestOutput](), + WithString("input", Required()), + ), + expectedOutputSchema: true, + }, + { + name: "no output schema is set", + tool: NewTool("test_tool", + WithDescription("Test tool with no output schema"), + WithString("input", Required()), + ), + expectedOutputSchema: false, + }, + } - // Marshal and verify structure - data, err := json.Marshal(tool) - assert.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal and verify structure + data, err := json.Marshal(tt.tool) + assert.NoError(t, err) - var toolData map[string]any - err = json.Unmarshal(data, &toolData) - assert.NoError(t, err) + var toolData map[string]any + err = json.Unmarshal(data, &toolData) + assert.NoError(t, err) - // Verify outputSchema exists - outputSchema, exists := toolData["outputSchema"] - assert.True(t, exists) - assert.NotNil(t, outputSchema) + // Verify outputSchema exists + outputSchema, exists := toolData["outputSchema"] + if tt.expectedOutputSchema { + assert.True(t, exists) + assert.NotNil(t, outputSchema) + } else { + assert.False(t, exists) + assert.Nil(t, outputSchema) + } + }) + } } // TestNewToolResultStructured tests that the NewToolResultStructured function diff --git a/server/sampling.go b/server/sampling.go index 4423ccf5f..2118db155 100644 --- a/server/sampling.go +++ b/server/sampling.go @@ -12,7 +12,7 @@ import ( func (s *MCPServer) EnableSampling() { s.capabilitiesMu.Lock() defer s.capabilitiesMu.Unlock() - + enabled := true s.capabilities.sampling = &enabled } diff --git a/server/sampling_test.go b/server/sampling_test.go index fbecdd70d..012bf2fd9 100644 --- a/server/sampling_test.go +++ b/server/sampling_test.go @@ -116,7 +116,7 @@ func TestMCPServer_RequestSampling_Success(t *testing.T) { func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { server := NewMCPServer("test", "1.0.0") - + // Verify sampling capability is not set initially ctx := context.Background() initRequest := mcp.InitializeRequest{ @@ -129,25 +129,25 @@ func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { Capabilities: mcp.ClientCapabilities{}, }, } - + result, err := server.handleInitialize(ctx, 1, initRequest) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if result.Capabilities.Sampling != nil { t.Error("sampling capability should not be set before EnableSampling() is called") } - + // Enable sampling server.EnableSampling() - + // Verify sampling capability is now set result, err = server.handleInitialize(ctx, 2, initRequest) if err != nil { t.Fatalf("unexpected error after EnableSampling(): %v", err) } - + if result.Capabilities.Sampling == nil { t.Error("sampling capability should be set after EnableSampling() is called") } diff --git a/server/server.go b/server/server.go index 366bf6611..688357280 100644 --- a/server/server.go +++ b/server/server.go @@ -43,6 +43,9 @@ type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mc // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc +// ResourceHandlerMiddleware is a middleware function that wraps a ResourceHandlerFunc. +type ResourceHandlerMiddleware func(ResourceHandlerFunc) ResourceHandlerFunc + // ToolFilterFunc is a function that filters tools based on context, typically using session information. type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool @@ -151,21 +154,22 @@ type MCPServer struct { capabilitiesMu sync.RWMutex toolFiltersMu sync.RWMutex - name string - version string - instructions string - resources map[string]resourceEntry - resourceTemplates map[string]resourceTemplateEntry - prompts map[string]mcp.Prompt - promptHandlers map[string]PromptHandlerFunc - tools map[string]ServerTool - toolHandlerMiddlewares []ToolHandlerMiddleware - toolFilters []ToolFilterFunc - notificationHandlers map[string]NotificationHandlerFunc - capabilities serverCapabilities - paginationLimit *int - sessions sync.Map - hooks *Hooks + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + toolHandlerMiddlewares []ToolHandlerMiddleware + resourceHandlerMiddlewares []ResourceHandlerMiddleware + toolFilters []ToolFilterFunc + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + paginationLimit *int + sessions sync.Map + hooks *Hooks } // WithPaginationLimit sets the pagination limit for the server. @@ -223,6 +227,36 @@ func WithToolHandlerMiddleware( } } +// WithResourceHandlerMiddleware allows adding a middleware for the +// resource handler call chain. +func WithResourceHandlerMiddleware( + resourceHandlerMiddleware ResourceHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.middlewareMu.Lock() + s.resourceHandlerMiddlewares = append(s.resourceHandlerMiddlewares, resourceHandlerMiddleware) + s.middlewareMu.Unlock() + } +} + +// WithResourceRecovery adds a middleware that recovers from panics in resource handlers. +func WithResourceRecovery() ServerOption { + return WithResourceHandlerMiddleware(func(next ResourceHandlerFunc) ResourceHandlerFunc { + return func(ctx context.Context, request mcp.ReadResourceRequest) (result []mcp.ResourceContents, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s resource handler: %v", + request.Params.URI, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + // WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools func WithToolFilter( toolFilter ToolFilterFunc, @@ -301,14 +335,16 @@ func NewMCPServer( opts ...ServerOption, ) *MCPServer { s := &MCPServer{ - resources: make(map[string]resourceEntry), - resourceTemplates: make(map[string]resourceTemplateEntry), - prompts: make(map[string]mcp.Prompt), - promptHandlers: make(map[string]PromptHandlerFunc), - tools: make(map[string]ServerTool), - name: name, - version: version, - notificationHandlers: make(map[string]NotificationHandlerFunc), + resources: make(map[string]resourceEntry), + resourceTemplates: make(map[string]resourceTemplateEntry), + prompts: make(map[string]mcp.Prompt), + promptHandlers: make(map[string]PromptHandlerFunc), + tools: make(map[string]ServerTool), + toolHandlerMiddlewares: make([]ToolHandlerMiddleware, 0), + resourceHandlerMiddlewares: make([]ResourceHandlerMiddleware, 0), + name: name, + version: version, + notificationHandlers: make(map[string]NotificationHandlerFunc), capabilities: serverCapabilities{ tools: nil, resources: nil, @@ -838,7 +874,17 @@ func (s *MCPServer) handleReadResource( if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler s.resourcesMu.RUnlock() - contents, err := handler(ctx, request) + + finalHandler := handler + s.middlewareMu.RLock() + mw := s.resourceHandlerMiddlewares + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.middlewareMu.RUnlock() + + contents, err := finalHandler(ctx, request) if err != nil { return nil, &requestError{ id: id, diff --git a/server/streamable_http.go b/server/streamable_http.go index 24ec1c95a..c97d9b747 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -8,6 +8,7 @@ import ( "mime" "net/http" "net/http/httptest" + "os" "strings" "sync" "sync/atomic" @@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption { } } +// WithTLSCert sets the TLS certificate and key files for HTTPS support. +// Both certFile and keyFile must be provided to enable TLS. +func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.tlsCertFile = certFile + s.tlsKeyFile = keyFile + } +} + // StreamableHTTPServer implements a Streamable-http based MCP server. // It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http @@ -131,6 +141,9 @@ type StreamableHTTPServer struct { listenHeartbeatInterval time.Duration logger util.Logger sessionLogLevels *sessionLogLevelsStore + + tlsCertFile string + tlsKeyFile string } // NewStreamableHTTPServer creates a new streamable-http server instance @@ -188,6 +201,19 @@ func (s *StreamableHTTPServer) Start(addr string) error { srv := s.httpServer s.mu.Unlock() + if s.tlsCertFile != "" || s.tlsKeyFile != "" { + if s.tlsCertFile == "" || s.tlsKeyFile == "" { + return fmt.Errorf("both TLS cert and key must be provided") + } + if _, err := os.Stat(s.tlsCertFile); err != nil { + return fmt.Errorf("failed to find TLS certificate file: %w", err) + } + if _, err := os.Stat(s.tlsKeyFile); err != nil { + return fmt.Errorf("failed to find TLS key file: %w", err) + } + return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile) + } + return srv.ListenAndServe() } @@ -237,9 +263,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } // Check if this is a sampling response (has result/error but no method) - isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && (jsonMessage.Result != nil || jsonMessage.Error != nil) - + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize // Handle sampling responses separately @@ -390,7 +416,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) return } defer s.server.UnregisterSession(r.Context(), sessionID) - + // Register session for sampling response delivery s.activeSessions.Store(sessionID, session) defer s.activeSessions.Delete(sessionID) @@ -743,18 +769,18 @@ type streamableHttpSession struct { logLevels *sessionLogLevelsStore // Sampling support for bidirectional communication - samplingRequestChan chan samplingRequestItem // server -> client sampling requests - samplingRequests sync.Map // requestID -> pending sampling request context - requestIDCounter atomic.Int64 // for generating unique request IDs + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, - samplingRequestChan: make(chan samplingRequestItem, 10), + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), } return s } @@ -810,21 +836,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { // Generate unique request ID requestID := s.requestIDCounter.Add(1) - + // Create response channel for this specific request responseChan := make(chan samplingResponseItem, 1) - + // Create the sampling request item samplingRequest := samplingRequestItem{ requestID: requestID, request: request, response: responseChan, } - + // Store the pending request s.samplingRequests.Store(requestID, responseChan) defer s.samplingRequests.Delete(requestID) - + // Send the sampling request via the channel (non-blocking) select { case s.samplingRequestChan <- samplingRequest: @@ -834,7 +860,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp default: return nil, fmt.Errorf("sampling request queue is full - server overloaded") } - + // Wait for response or context cancellation select { case response := <-responseChan: diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 4cf57838c..50be27fa7 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -213,4 +213,4 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { if !strings.Contains(err.Error(), "queue is full") { t.Errorf("Expected queue full error, got: %v", err) } -} \ No newline at end of file +} diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 105fd18ce..b464e1bdd 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -894,6 +894,26 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { } } +func TestStreamableHTTPServer_TLS(t *testing.T) { + t.Run("TLS options are set correctly", func(t *testing.T) { + mcpServer := NewMCPServer("test-mcp-server", "1.0.0") + certFile := "/path/to/cert.pem" + keyFile := "/path/to/key.pem" + + server := NewStreamableHTTPServer( + mcpServer, + WithTLSCert(certFile, keyFile), + ) + + if server.tlsCertFile != certFile { + t.Errorf("Expected tlsCertFile to be %s, got %s", certFile, server.tlsCertFile) + } + if server.tlsKeyFile != keyFile { + t.Errorf("Expected tlsKeyFile to be %s, got %s", keyFile, server.tlsKeyFile) + } + }) +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) diff --git a/www/docs/pages/servers/basics.mdx b/www/docs/pages/servers/basics.mdx index 7b33f33ce..b83bbdfab 100644 --- a/www/docs/pages/servers/basics.mdx +++ b/www/docs/pages/servers/basics.mdx @@ -182,6 +182,7 @@ Configure transport-specific options: httpServer := server.NewStreamableHTTPServer(s, server.WithEndpointPath("/mcp"), server.WithStateless(true), + server.WithTLSCert("/path/to/cert.pem", "/path/to/key.pem"), ) if err := httpServer.Start(":8080"); err != nil {