From 4e353ac80e0e18a017988b43c3b1ffd3be1fb732 Mon Sep 17 00:00:00 2001 From: Izzy Date: Sun, 24 Aug 2025 15:05:13 +0800 Subject: [PATCH 1/9] fix: use bufio.Scanner for stdio transport to avoid panic when stdio mcp server outputs a long line (#464) Change-Id: Iaaaf44f80d2e49f5275c5f6903c87dcb4dbb53a3 Co-authored-by: tangyuyi --- client/transport/stdio.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 488164c7..f3a95f4b 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -27,7 +27,7 @@ type Stdio struct { cmd *exec.Cmd cmdFunc CommandFunc stdin io.WriteCloser - stdout *bufio.Reader + stdout *bufio.Scanner stderr io.ReadCloser responses map[string]chan *JSONRPCResponse mu sync.RWMutex @@ -72,7 +72,7 @@ func WithCommandLogger(logger util.Logger) StdioOption { func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio { return &Stdio{ stdin: output, - stdout: bufio.NewReader(input), + stdout: bufio.NewScanner(input), stderr: logging, responses: make(map[string]chan *JSONRPCResponse), @@ -180,7 +180,7 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { c.cmd = cmd c.stdin = stdin c.stderr = stderr - c.stdout = bufio.NewReader(stdout) + c.stdout = bufio.NewScanner(stdout) if err := cmd.Start(); err != nil { return fmt.Errorf("failed to start command: %w", err) @@ -247,14 +247,15 @@ func (c *Stdio) readResponses() { case <-c.done: return default: - line, err := c.stdout.ReadString('\n') - if err != nil { - if err != io.EOF && !errors.Is(err, context.Canceled) { + if !c.stdout.Scan() { + err := c.stdout.Err() + if err != nil && !errors.Is(err, context.Canceled) { c.logger.Errorf("Error reading from stdout: %v", err) } return } + line := c.stdout.Text() // First try to parse as a generic message to check for ID field var baseMessage struct { JSONRPC string `json:"jsonrpc"` From b9243915a8c31a0865d52a6df70c05bffe9ac170 Mon Sep 17 00:00:00 2001 From: Alejandro Borbolla <52978371+alex210501@users.noreply.github.com> Date: Tue, 2 Sep 2025 09:44:32 +0200 Subject: [PATCH 2/9] fix(tool): Return the `OutputSchema` from the tool definition (#571) This commit returns the `OutputSchema` from the tool definition as per the MCP spec: https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema The behaviour of `WithOutputSchema()` also changed as `RawOutputSchema` is not longer populated, but `OutputSchema` is from the `T` generic type. The only way now to set `RawOutputSchema` is through the `WithRawOutputSchema()` method. --- mcp/tools.go | 25 ++++++++++++++++++++++--- mcp/tools_test.go | 2 +- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/mcp/tools.go b/mcp/tools.go index 3f367492..185aefa6 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -565,6 +565,8 @@ type Tool struct { InputSchema ToolInputSchema `json:"inputSchema"` // Alternative to InputSchema - allows arbitrary JSON Schema to be provided RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // A JSON Schema object defining the expected output returned by the tool . + OutputSchema ToolOutputSchema `json:"outputSchema,omitempty"` // Optional JSON Schema defining expected output structure RawOutputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling // Optional properties describing tool behavior @@ -601,7 +603,12 @@ func (t Tool) MarshalJSON() ([]byte, error) { // Add output schema if present if t.RawOutputSchema != nil { + if t.OutputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both OutputSchema and RawOutputSchema set: %w", t.Name, errToolSchemaConflict) + } m["outputSchema"] = t.RawOutputSchema + } else { + m["outputSchema"] = t.OutputSchema } m["annotations"] = t.Annotations @@ -609,15 +616,19 @@ func (t Tool) MarshalJSON() ([]byte, error) { return json.Marshal(m) } -type ToolInputSchema struct { +// ToolArgumentsSchema represents a JSON Schema for tool arguments. +type ToolArgumentsSchema struct { Defs map[string]any `json:"$defs,omitempty"` Type string `json:"type"` Properties map[string]any `json:"properties,omitempty"` Required []string `json:"required,omitempty"` } +type ToolInputSchema ToolArgumentsSchema // For retro-compatibility +type ToolOutputSchema ToolArgumentsSchema + // MarshalJSON implements the json.Marshaler interface for ToolInputSchema. -func (tis ToolInputSchema) MarshalJSON() ([]byte, error) { +func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { m := make(map[string]any) m["type"] = tis.Type @@ -780,7 +791,15 @@ func WithOutputSchema[T any]() ToolOption { return } - t.RawOutputSchema = json.RawMessage(mcpSchema) + // Retrieve the schema from raw JSON + if err := json.Unmarshal(mcpSchema, &t.OutputSchema); err != nil { + // Skip and maintain backward compatibility + return + } + + // Always set the type to "object" as of the current MCP spec + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + t.OutputSchema.Type = "object" } } diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 13c0f564..270bf64b 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -593,7 +593,7 @@ func TestToolWithOutputSchema(t *testing.T) { ) // Check that RawOutputSchema was set - assert.NotNil(t, tool.RawOutputSchema) + assert.NotNil(t, tool.OutputSchema) // Marshal and verify structure data, err := json.Marshal(tool) From ef80a50bdce2355c0a70e14d3f95bde90f5887d8 Mon Sep 17 00:00:00 2001 From: TJ Hoplock <33664289+tjhop@users.noreply.github.com> Date: Tue, 2 Sep 2025 03:46:03 -0400 Subject: [PATCH 3/9] feat: add resource handler middleware capability (#569) * feat: add resource handler middleware capability This is a fairly direct adaption of the existing tool handler middleware to also allow support for resource middlewares. Use case: I'm working on an MCP server that manages an API client that is used for both tool and resource calls. The tool handler middleware provides a nice pattern to wrap tool calls that fits some use cases better than the before/after tool call hooks. It would be helpful to have first party support for this pattern in the library so I don't need to work around it with custom closures etc. Notes: - There are currently (that I can find) that exercise the existing tool handler middleware logic, so I did not add tests for the resource handler middleware logic. - Existing docs, specifically those for the streamable HTTP transport, reference some middleware functions (for both tools and resources) that don't exist (ex: `s.AddToolMiddleware` does not, `s.AddResourceMiddleware` does not exist, etc). It seems they may be out of date. Happy to discuss updates to docs in a separate PR. Signed-off-by: TJ Hoplock * feat: add `WithResourceRecovery()` ServerOption The existing `WithRecovery()` ServerOption is tool oriented, this adds a corresponding recovery handler for resources. This will be especially useful if Resource middlewares are used, where things may possibly/need to panic. Signed-off-by: TJ Hoplock --------- Signed-off-by: TJ Hoplock --- server/server.go | 94 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 24 deletions(-) diff --git a/server/server.go b/server/server.go index 366bf661..68835728 100644 --- a/server/server.go +++ b/server/server.go @@ -43,6 +43,9 @@ type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mc // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc +// ResourceHandlerMiddleware is a middleware function that wraps a ResourceHandlerFunc. +type ResourceHandlerMiddleware func(ResourceHandlerFunc) ResourceHandlerFunc + // ToolFilterFunc is a function that filters tools based on context, typically using session information. type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool @@ -151,21 +154,22 @@ type MCPServer struct { capabilitiesMu sync.RWMutex toolFiltersMu sync.RWMutex - name string - version string - instructions string - resources map[string]resourceEntry - resourceTemplates map[string]resourceTemplateEntry - prompts map[string]mcp.Prompt - promptHandlers map[string]PromptHandlerFunc - tools map[string]ServerTool - toolHandlerMiddlewares []ToolHandlerMiddleware - toolFilters []ToolFilterFunc - notificationHandlers map[string]NotificationHandlerFunc - capabilities serverCapabilities - paginationLimit *int - sessions sync.Map - hooks *Hooks + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + toolHandlerMiddlewares []ToolHandlerMiddleware + resourceHandlerMiddlewares []ResourceHandlerMiddleware + toolFilters []ToolFilterFunc + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + paginationLimit *int + sessions sync.Map + hooks *Hooks } // WithPaginationLimit sets the pagination limit for the server. @@ -223,6 +227,36 @@ func WithToolHandlerMiddleware( } } +// WithResourceHandlerMiddleware allows adding a middleware for the +// resource handler call chain. +func WithResourceHandlerMiddleware( + resourceHandlerMiddleware ResourceHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.middlewareMu.Lock() + s.resourceHandlerMiddlewares = append(s.resourceHandlerMiddlewares, resourceHandlerMiddleware) + s.middlewareMu.Unlock() + } +} + +// WithResourceRecovery adds a middleware that recovers from panics in resource handlers. +func WithResourceRecovery() ServerOption { + return WithResourceHandlerMiddleware(func(next ResourceHandlerFunc) ResourceHandlerFunc { + return func(ctx context.Context, request mcp.ReadResourceRequest) (result []mcp.ResourceContents, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s resource handler: %v", + request.Params.URI, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + // WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools func WithToolFilter( toolFilter ToolFilterFunc, @@ -301,14 +335,16 @@ func NewMCPServer( opts ...ServerOption, ) *MCPServer { s := &MCPServer{ - resources: make(map[string]resourceEntry), - resourceTemplates: make(map[string]resourceTemplateEntry), - prompts: make(map[string]mcp.Prompt), - promptHandlers: make(map[string]PromptHandlerFunc), - tools: make(map[string]ServerTool), - name: name, - version: version, - notificationHandlers: make(map[string]NotificationHandlerFunc), + resources: make(map[string]resourceEntry), + resourceTemplates: make(map[string]resourceTemplateEntry), + prompts: make(map[string]mcp.Prompt), + promptHandlers: make(map[string]PromptHandlerFunc), + tools: make(map[string]ServerTool), + toolHandlerMiddlewares: make([]ToolHandlerMiddleware, 0), + resourceHandlerMiddlewares: make([]ResourceHandlerMiddleware, 0), + name: name, + version: version, + notificationHandlers: make(map[string]NotificationHandlerFunc), capabilities: serverCapabilities{ tools: nil, resources: nil, @@ -838,7 +874,17 @@ func (s *MCPServer) handleReadResource( if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler s.resourcesMu.RUnlock() - contents, err := handler(ctx, request) + + finalHandler := handler + s.middlewareMu.RLock() + mw := s.resourceHandlerMiddlewares + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.middlewareMu.RUnlock() + + contents, err := finalHandler(ctx, request) if err != nil { return nil, &requestError{ id: id, From 40ce109d589965a99cc6b8dfac6f8ba6e7b1b708 Mon Sep 17 00:00:00 2001 From: opencow Date: Tue, 2 Sep 2025 03:46:59 -0400 Subject: [PATCH 4/9] feat: add tls support for streamable-http (#568) * tests tls * log * fail * clean * docs * nit * nit --- server/streamable_http.go | 58 ++++++++++++++++++++++--------- server/streamable_http_test.go | 20 +++++++++++ www/docs/pages/servers/basics.mdx | 1 + 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index 24ec1c95..c97d9b74 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -8,6 +8,7 @@ import ( "mime" "net/http" "net/http/httptest" + "os" "strings" "sync" "sync/atomic" @@ -93,6 +94,15 @@ func WithLogger(logger util.Logger) StreamableHTTPOption { } } +// WithTLSCert sets the TLS certificate and key files for HTTPS support. +// Both certFile and keyFile must be provided to enable TLS. +func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.tlsCertFile = certFile + s.tlsKeyFile = keyFile + } +} + // StreamableHTTPServer implements a Streamable-http based MCP server. // It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http @@ -131,6 +141,9 @@ type StreamableHTTPServer struct { listenHeartbeatInterval time.Duration logger util.Logger sessionLogLevels *sessionLogLevelsStore + + tlsCertFile string + tlsKeyFile string } // NewStreamableHTTPServer creates a new streamable-http server instance @@ -188,6 +201,19 @@ func (s *StreamableHTTPServer) Start(addr string) error { srv := s.httpServer s.mu.Unlock() + if s.tlsCertFile != "" || s.tlsKeyFile != "" { + if s.tlsCertFile == "" || s.tlsKeyFile == "" { + return fmt.Errorf("both TLS cert and key must be provided") + } + if _, err := os.Stat(s.tlsCertFile); err != nil { + return fmt.Errorf("failed to find TLS certificate file: %w", err) + } + if _, err := os.Stat(s.tlsKeyFile); err != nil { + return fmt.Errorf("failed to find TLS key file: %w", err) + } + return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile) + } + return srv.ListenAndServe() } @@ -237,9 +263,9 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } // Check if this is a sampling response (has result/error but no method) - isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && (jsonMessage.Result != nil || jsonMessage.Error != nil) - + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize // Handle sampling responses separately @@ -390,7 +416,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) return } defer s.server.UnregisterSession(r.Context(), sessionID) - + // Register session for sampling response delivery s.activeSessions.Store(sessionID, session) defer s.activeSessions.Delete(sessionID) @@ -743,18 +769,18 @@ type streamableHttpSession struct { logLevels *sessionLogLevelsStore // Sampling support for bidirectional communication - samplingRequestChan chan samplingRequestItem // server -> client sampling requests - samplingRequests sync.Map // requestID -> pending sampling request context - requestIDCounter atomic.Int64 // for generating unique request IDs + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, - samplingRequestChan: make(chan samplingRequestItem, 10), + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), } return s } @@ -810,21 +836,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { // Generate unique request ID requestID := s.requestIDCounter.Add(1) - + // Create response channel for this specific request responseChan := make(chan samplingResponseItem, 1) - + // Create the sampling request item samplingRequest := samplingRequestItem{ requestID: requestID, request: request, response: responseChan, } - + // Store the pending request s.samplingRequests.Store(requestID, responseChan) defer s.samplingRequests.Delete(requestID) - + // Send the sampling request via the channel (non-blocking) select { case s.samplingRequestChan <- samplingRequest: @@ -834,7 +860,7 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp default: return nil, fmt.Errorf("sampling request queue is full - server overloaded") } - + // Wait for response or context cancellation select { case response := <-responseChan: diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 105fd18c..b464e1bd 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -894,6 +894,26 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { } } +func TestStreamableHTTPServer_TLS(t *testing.T) { + t.Run("TLS options are set correctly", func(t *testing.T) { + mcpServer := NewMCPServer("test-mcp-server", "1.0.0") + certFile := "/path/to/cert.pem" + keyFile := "/path/to/key.pem" + + server := NewStreamableHTTPServer( + mcpServer, + WithTLSCert(certFile, keyFile), + ) + + if server.tlsCertFile != certFile { + t.Errorf("Expected tlsCertFile to be %s, got %s", certFile, server.tlsCertFile) + } + if server.tlsKeyFile != keyFile { + t.Errorf("Expected tlsKeyFile to be %s, got %s", keyFile, server.tlsKeyFile) + } + }) +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) diff --git a/www/docs/pages/servers/basics.mdx b/www/docs/pages/servers/basics.mdx index 7b33f33c..b83bbdfa 100644 --- a/www/docs/pages/servers/basics.mdx +++ b/www/docs/pages/servers/basics.mdx @@ -182,6 +182,7 @@ Configure transport-specific options: httpServer := server.NewStreamableHTTPServer(s, server.WithEndpointPath("/mcp"), server.WithStateless(true), + server.WithTLSCert("/path/to/cert.pem", "/path/to/key.pem"), ) if err := httpServer.Start(":8080"); err != nil { From d2f81b67b850666960c7a7cbc23b5db9556b2a5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Tue, 2 Sep 2025 15:49:36 +0800 Subject: [PATCH 5/9] fix: prevent double-starting stdio transport in client (#564) * fix: prevent double-starting stdio transport in client Avoid starting stdio transport twice by checking transport type before calling Start(). The stdio transport from NewStdioMCPClientWithOptions is already started and doesn't need to be started again. * docs: update NewStdioMCPClient comment for clarity --- client/client.go | 13 ++++++++++--- client/stdio.go | 3 +-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/client/client.go b/client/client.go index cda7665e..0d47fcbf 100644 --- a/client/client.go +++ b/client/client.go @@ -77,9 +77,16 @@ func (c *Client) Start(ctx context.Context) error { if c.transport == nil { return fmt.Errorf("transport is nil") } - err := c.transport.Start(ctx) - if err != nil { - return err + + if _, ok := c.transport.(*transport.Stdio); !ok { + // the stdio transport from NewStdioMCPClientWithOptions + // is already started, dont start again. + // + // Start the transport for other transport types + err := c.transport.Start(ctx) + if err != nil { + return err + } } c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { diff --git a/client/stdio.go b/client/stdio.go index 199ec14c..6e6e1af9 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -12,7 +12,7 @@ import ( // It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. // Returns an error if the subprocess cannot be started or the pipes cannot be created. // -// NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually. +// NOTICE: NewStdioMCPClient will start the connection automatically. // This is for backward compatibility. func NewStdioMCPClient( command string, @@ -28,7 +28,6 @@ func NewStdioMCPClient( // such as setting a custom command function. // // NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. -// Don't call the Start method manually. // This is for backward compatibility. func NewStdioMCPClientWithOptions( command string, From 004ca9e0341eb1a9702711e2ca5250527b330e7f Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Tue, 2 Sep 2025 15:51:57 +0800 Subject: [PATCH 6/9] docs(client): improve server reliability and error handling (#560) * docs(client): improve server reliability and error handling - Add a health check step using Ping to verify server availability - Change error handling for listing tools and resources to terminate the program on failure Signed-off-by: Bo-Yi Wu * style: remove emojis for neutral health check log output - Remove emojis from health check log and success messages for a more neutral output Signed-off-by: appleboy --------- Signed-off-by: Bo-Yi Wu Signed-off-by: appleboy --- examples/simple_client/main.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/examples/simple_client/main.go b/examples/simple_client/main.go index c0f48593..d7497b6c 100644 --- a/examples/simple_client/main.go +++ b/examples/simple_client/main.go @@ -121,18 +121,24 @@ func main() { serverInfo.ServerInfo.Version) fmt.Printf("Server capabilities: %+v\n", serverInfo.Capabilities) + // Perform health check using ping + fmt.Println("Performing health check...") + if err := c.Ping(ctx); err != nil { + log.Fatalf("Health check failed: %v", err) + } + fmt.Println("Server is alive and responding") + // List available tools if the server supports them if serverInfo.Capabilities.Tools != nil { fmt.Println("Fetching available tools...") toolsRequest := mcp.ListToolsRequest{} toolsResult, err := c.ListTools(ctx, toolsRequest) if err != nil { - log.Printf("Failed to list tools: %v", err) - } else { - fmt.Printf("Server has %d tools available\n", len(toolsResult.Tools)) - for i, tool := range toolsResult.Tools { - fmt.Printf(" %d. %s - %s\n", i+1, tool.Name, tool.Description) - } + log.Fatalf("Failed to list tools: %v", err) + } + fmt.Printf("Server has %d tools available\n", len(toolsResult.Tools)) + for i, tool := range toolsResult.Tools { + fmt.Printf(" %d. %s - %s\n", i+1, tool.Name, tool.Description) } } @@ -142,12 +148,11 @@ func main() { resourcesRequest := mcp.ListResourcesRequest{} resourcesResult, err := c.ListResources(ctx, resourcesRequest) if err != nil { - log.Printf("Failed to list resources: %v", err) - } else { - fmt.Printf("Server has %d resources available\n", len(resourcesResult.Resources)) - for i, resource := range resourcesResult.Resources { - fmt.Printf(" %d. %s - %s\n", i+1, resource.URI, resource.Name) - } + log.Fatalf("Failed to list resources: %v", err) + } + fmt.Printf("Server has %d resources available\n", len(resourcesResult.Resources)) + for i, resource := range resourcesResult.Resources { + fmt.Printf(" %d. %s - %s\n", i+1, resource.URI, resource.Name) } } From 5298f646e70eca25c78c4e6465e880db56881d47 Mon Sep 17 00:00:00 2001 From: Darragh O'Reilly Date: Tue, 2 Sep 2025 08:56:09 +0100 Subject: [PATCH 7/9] oauth client example: skip DCR if MCP_CLIENT_ID is set (#549) If the user has set the MCP_CLIENT_ID environmental variable then they don't want Dynamic Client Registration and it should be skipped. --- examples/oauth_client/main.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/oauth_client/main.go b/examples/oauth_client/main.go index 639d8cb7..27d3b618 100644 --- a/examples/oauth_client/main.go +++ b/examples/oauth_client/main.go @@ -129,9 +129,11 @@ func maybeAuthorize(err error) { log.Fatalf("Failed to generate state: %v", err) } - err = oauthHandler.RegisterClient(context.Background(), "mcp-go-oauth-example") - if err != nil { - log.Fatalf("Failed to register client: %v", err) + if oauthHandler.GetClientID() == "" { + err = oauthHandler.RegisterClient(context.Background(), "mcp-go-oauth-example") + if err != nil { + log.Fatalf("Failed to register client: %v", err) + } } // Get the authorization URL From 35fc389d56010b24ef49f966868b8b948d2af48e Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 2 Sep 2025 10:59:42 +0300 Subject: [PATCH 8/9] Format --- client/transport/sse.go | 14 ++-- client/transport/streamable_http.go | 12 +-- .../streamable_http_sampling_test.go | 84 +++++++++---------- examples/sampling_client/main.go | 4 +- examples/sampling_http_client/main.go | 10 +-- examples/sampling_http_server/main.go | 2 +- server/sampling.go | 2 +- server/sampling_test.go | 12 +-- server/streamable_http_sampling_test.go | 2 +- 9 files changed, 71 insertions(+), 71 deletions(-) diff --git a/client/transport/sse.go b/client/transport/sse.go index 70a39190..305c9316 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -36,12 +36,12 @@ type SSE struct { headerFunc HTTPHeaderFunc logger util.Logger - started atomic.Bool - closed atomic.Bool - cancelSSEStream context.CancelFunc - protocolVersion atomic.Value // string - onConnectionLost func(error) - connectionLostMu sync.RWMutex + started atomic.Bool + closed atomic.Bool + cancelSSEStream context.CancelFunc + protocolVersion atomic.Value // string + onConnectionLost func(error) + connectionLostMu sync.RWMutex // OAuth support oauthHandler *OAuthHandler @@ -220,7 +220,7 @@ func (c *SSE) readSSE(reader io.ReadCloser) { c.connectionLostMu.RLock() handler := c.onConnectionLost c.connectionLostMu.RUnlock() - + if handler != nil { // This is not actually an error - HTTP2 idle timeout disconnection handler(err) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 268aeb34..40a5bc69 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -605,7 +605,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) err := c.createGETConnectionToServer(connectCtx) cancel() - + if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") @@ -621,7 +621,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { if err != nil { c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) } - + // Use context-aware sleep select { case <-time.After(retryInterval): @@ -704,15 +704,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Create a new context with timeout for request handling, respecting parent context requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - + response, err := handler(requestCtx, request) if err != nil { c.logger.Errorf("error handling request %s: %v", request.Method, err) - + // Determine appropriate JSON-RPC error code based on error type var errorCode int var errorMessage string - + // Check for specific sampling-related errors if errors.Is(err, context.Canceled) { errorCode = -32800 // Request cancelled @@ -731,7 +731,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON errorMessage = err.Error() } } - + // Send error response errorResponse := &JSONRPCResponse{ JSONRPC: "2.0", diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index edba61ea..4a38f280 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -16,27 +16,27 @@ import ( // TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport func TestStreamableHTTP_SamplingFlow(t *testing.T) { - // Create simple test server + // Create simple test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Just respond OK to any requests w.WriteHeader(http.StatusOK) })) defer server.Close() - + // Create HTTP client transport client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Set up sampling request handler var handledRequest *JSONRPCRequest handlerCalled := make(chan struct{}) client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { handledRequest = &request close(handlerCalled) - + // Simulate sampling handler response result := map[string]any{ "role": "assistant", @@ -47,25 +47,25 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { "model": "test-model", "stopReason": "stop_sequence", } - + resultBytes, _ := json.Marshal(result) - + return &JSONRPCResponse{ JSONRPC: "2.0", ID: request.ID, Result: resultBytes, }, nil }) - + // Start the client ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - + err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Test direct request handling (simulating a sampling request) samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -83,10 +83,10 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { }, }, } - + // Directly test request handling client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for handler to be called select { case <-handlerCalled: @@ -94,12 +94,12 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Handler was not called within timeout") } - + // Verify the request was handled if handledRequest == nil { t.Fatal("Sampling request was not handled") } - + if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) { t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method) } @@ -109,7 +109,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { var errorHandled sync.WaitGroup errorHandled.Add(1) - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]any @@ -118,7 +118,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { w.WriteHeader(http.StatusOK) return } - + // Check if this is an error response if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]any) @@ -132,25 +132,25 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer server.Close() - + client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Set up request handler that returns an error client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { return nil, fmt.Errorf("sampling failed") }) - + // Start the client ctx := context.Background() err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Simulate incoming sampling request samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -158,10 +158,10 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { Method: string(mcp.MethodSamplingCreateMessage), Params: map[string]any{}, } - + // This should trigger error handling client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for error to be handled errorHandled.Wait() } @@ -170,7 +170,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { var errorReceived bool errorReceivedChan := make(chan struct{}) - + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]any @@ -179,12 +179,12 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { w.WriteHeader(http.StatusOK) return } - + // Check if this is an error response with method not found if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]any) if code, ok := errorMap["code"].(float64); ok && code == -32601 { - if message, ok := errorMap["message"].(string); ok && + if message, ok := errorMap["message"].(string); ok && strings.Contains(message, "no handler configured") { errorReceived = true close(errorReceivedChan) @@ -195,21 +195,21 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer server.Close() - + client, err := NewStreamableHTTP(server.URL) if err != nil { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Don't set any request handler - + ctx := context.Background() err = client.Start(ctx) if err != nil { t.Fatalf("Failed to start client: %v", err) } - + // Simulate incoming sampling request samplingRequest := JSONRPCRequest{ JSONRPC: "2.0", @@ -217,10 +217,10 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { Method: string(mcp.MethodSamplingCreateMessage), Params: map[string]any{}, } - + // This should trigger "method not found" error client.handleIncomingRequest(ctx, samplingRequest) - + // Wait for error to be received select { case <-errorReceivedChan: @@ -228,7 +228,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Method not found error was not received within timeout") } - + if !errorReceived { t.Error("Expected method not found error, but didn't receive it") } @@ -241,13 +241,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } defer client.Close() - + // Verify it implements BidirectionalInterface _, ok := any(client).(BidirectionalInterface) if !ok { t.Error("StreamableHTTP should implement BidirectionalInterface") } - + // Test SetRequestHandler handlerSet := false handlerSetChan := make(chan struct{}) @@ -256,7 +256,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { close(handlerSetChan) return nil, nil }) - + // Verify handler was set by triggering it ctx := context.Background() client.handleIncomingRequest(ctx, JSONRPCRequest{ @@ -264,7 +264,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { ID: mcp.NewRequestId(1), Method: "test", }) - + // Wait for handler to be called select { case <-handlerSetChan: @@ -272,7 +272,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { case <-time.After(1 * time.Second): t.Fatal("Handler was not called within timeout") } - + if !handlerSet { t.Error("Request handler was not properly set or called") } @@ -315,16 +315,16 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Track which requests have been received and their completion order var requestOrder []int var orderMutex sync.Mutex - + // Set up request handler that simulates different processing times client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { // Extract request ID to determine processing time requestIDValue := request.ID.Value() - + var delay time.Duration var responseText string var requestNum int - + // First request (ID 1) takes longer, second request (ID 2) completes faster if requestIDValue == int64(1) { delay = 100 * time.Millisecond @@ -341,7 +341,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Simulate processing time time.Sleep(delay) - + // Record completion order orderMutex.Lock() requestOrder = append(requestOrder, requestNum) @@ -428,7 +428,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Verify completion order: request 2 should complete first orderMutex.Lock() defer orderMutex.Unlock() - + if len(requestOrder) != 2 { t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder)) } @@ -493,4 +493,4 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { } } } -} \ No newline at end of file +} diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go index 093b5981..a655fde6 100644 --- a/examples/sampling_client/main.go +++ b/examples/sampling_client/main.go @@ -95,7 +95,7 @@ func main() { // Setup graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - + // Create a context that cancels on signal ctx, cancel := context.WithCancel(ctx) go func() { @@ -103,7 +103,7 @@ func main() { log.Println("Received shutdown signal, closing client...") cancel() }() - + // Move defer after error checking defer func() { if err := mcpClient.Close(); err != nil { diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go index 98817e6f..946223f7 100644 --- a/examples/sampling_http_client/main.go +++ b/examples/sampling_http_client/main.go @@ -63,7 +63,7 @@ func main() { log.Fatalf("Failed to create HTTP transport: %v", err) } defer httpTransport.Close() - + // Create client with sampling support mcpClient := client.NewClient( httpTransport, @@ -81,7 +81,7 @@ func main() { initRequest := mcp.InitializeRequest{ Params: mcp.InitializeParams{ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - Capabilities: mcp.ClientCapabilities{ + Capabilities: mcp.ClientCapabilities{ // Sampling capability will be automatically added by the client }, ClientInfo: mcp.Implementation{ @@ -90,7 +90,7 @@ func main() { }, }, } - + _, err = mcpClient.Initialize(ctx, initRequest) if err != nil { log.Fatalf("Failed to initialize MCP session: %v", err) @@ -102,7 +102,7 @@ func main() { // In a real application, you would keep the client running to handle sampling requests // For this example, we'll just demonstrate that it's working - + // Keep the client running (in a real app, you'd have your main application logic here) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) @@ -113,4 +113,4 @@ func main() { case <-sigChan: log.Println("Received shutdown signal") } -} \ No newline at end of file +} diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go index 95a2bf29..a178acce 100644 --- a/examples/sampling_http_server/main.go +++ b/examples/sampling_http_server/main.go @@ -147,4 +147,4 @@ func main() { if err := httpServer.Start(":8080"); err != nil { log.Fatalf("Server failed to start: %v", err) } -} \ No newline at end of file +} diff --git a/server/sampling.go b/server/sampling.go index 4423ccf5..2118db15 100644 --- a/server/sampling.go +++ b/server/sampling.go @@ -12,7 +12,7 @@ import ( func (s *MCPServer) EnableSampling() { s.capabilitiesMu.Lock() defer s.capabilitiesMu.Unlock() - + enabled := true s.capabilities.sampling = &enabled } diff --git a/server/sampling_test.go b/server/sampling_test.go index fbecdd70..012bf2fd 100644 --- a/server/sampling_test.go +++ b/server/sampling_test.go @@ -116,7 +116,7 @@ func TestMCPServer_RequestSampling_Success(t *testing.T) { func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { server := NewMCPServer("test", "1.0.0") - + // Verify sampling capability is not set initially ctx := context.Background() initRequest := mcp.InitializeRequest{ @@ -129,25 +129,25 @@ func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { Capabilities: mcp.ClientCapabilities{}, }, } - + result, err := server.handleInitialize(ctx, 1, initRequest) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if result.Capabilities.Sampling != nil { t.Error("sampling capability should not be set before EnableSampling() is called") } - + // Enable sampling server.EnableSampling() - + // Verify sampling capability is now set result, err = server.handleInitialize(ctx, 2, initRequest) if err != nil { t.Fatalf("unexpected error after EnableSampling(): %v", err) } - + if result.Capabilities.Sampling == nil { t.Error("sampling capability should be set after EnableSampling() is called") } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 4cf57838..50be27fa 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -213,4 +213,4 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { if !strings.Contains(err.Error(), "queue is full") { t.Errorf("Expected queue full error, got: %v", err) } -} \ No newline at end of file +} From 3d1bfcabfa50c0d92e5302bff9c4d7bc350f9913 Mon Sep 17 00:00:00 2001 From: Alejandro Borbolla <52978371+alex210501@users.noreply.github.com> Date: Tue, 2 Sep 2025 17:20:27 +0200 Subject: [PATCH 9/9] fix(tool): Do not return empty `outputSchema` (#573) * fix(tool): Do not return empty `outputSchema` If the output schema is not specify, we do not return it as we break the MCP specification that are expecting the following format: ```json "outputSchema": { "type": "object", "properties": {}, "required": [] } ``` * fix(tool): typo in test --- mcp/tools.go | 2 +- mcp/tools_test.go | 60 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/mcp/tools.go b/mcp/tools.go index 185aefa6..493e8c77 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -607,7 +607,7 @@ func (t Tool) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("tool %s has both OutputSchema and RawOutputSchema set: %w", t.Name, errToolSchemaConflict) } m["outputSchema"] = t.RawOutputSchema - } else { + } else if t.OutputSchema.Type != "" { // If no output schema is specified, do not return anything m["outputSchema"] = t.OutputSchema } diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 270bf64b..4bb07aa5 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -586,27 +586,51 @@ func TestToolWithOutputSchema(t *testing.T) { Email string `json:"email,omitempty" jsonschema_description:"Email address"` } - tool := NewTool("test_tool", - WithDescription("Test tool with output schema"), - WithOutputSchema[TestOutput](), - WithString("input", Required()), - ) - - // Check that RawOutputSchema was set - assert.NotNil(t, tool.OutputSchema) + tests := []struct { + name string + tool Tool + expectedOutputSchema bool + }{ + { + name: "default behavior", + tool: NewTool("test_tool", + WithDescription("Test tool with output schema"), + WithOutputSchema[TestOutput](), + WithString("input", Required()), + ), + expectedOutputSchema: true, + }, + { + name: "no output schema is set", + tool: NewTool("test_tool", + WithDescription("Test tool with no output schema"), + WithString("input", Required()), + ), + expectedOutputSchema: false, + }, + } - // Marshal and verify structure - data, err := json.Marshal(tool) - assert.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal and verify structure + data, err := json.Marshal(tt.tool) + assert.NoError(t, err) - var toolData map[string]any - err = json.Unmarshal(data, &toolData) - assert.NoError(t, err) + var toolData map[string]any + err = json.Unmarshal(data, &toolData) + assert.NoError(t, err) - // Verify outputSchema exists - outputSchema, exists := toolData["outputSchema"] - assert.True(t, exists) - assert.NotNil(t, outputSchema) + // Verify outputSchema exists + outputSchema, exists := toolData["outputSchema"] + if tt.expectedOutputSchema { + assert.True(t, exists) + assert.NotNil(t, outputSchema) + } else { + assert.False(t, exists) + assert.Nil(t, outputSchema) + } + }) + } } // TestNewToolResultStructured tests that the NewToolResultStructured function