From 06ae01b1418203544b1400ade19e45dfc8e0ed17 Mon Sep 17 00:00:00 2001 From: cryo Date: Thu, 3 Jul 2025 00:38:25 +0800 Subject: [PATCH 001/221] mcp/streamable: fix typos (#83)   This PR updates the comments for `StreamableServerTransport.streamRequests` to correct typos and improve clarity. --- mcp/streamable.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index da950fb2..db3add85 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -223,13 +223,13 @@ type StreamableServerTransport struct { // TODO(rfindley): clean up once requests are handled. requestStreams map[JSONRPCID]streamID - // outstandingRequests tracks the set of unanswered incoming RPCs for each logical + // streamRequests tracks the set of unanswered incoming RPCs for each logical // stream. // // When the server has responded to each request, the stream should be // closed. // - // Lifecycle: outstandingRequests values persist as until the requests have been + // Lifecycle: streamRequests values persist as until the requests have been // replied to by the server. Notably, NOT until they are sent to an HTTP // response, as delivery is not guaranteed. streamRequests map[streamID]map[JSONRPCID]struct{} From ba7a06461a86b7d79c634758c41db3bdf8a1fbca Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 2 Jul 2025 15:11:26 -0400 Subject: [PATCH 002/221] examples/rate-limiting: change Go version to 1.23 (#89) The version is 1.25, which doesn't match the main module's go.mod. --- examples/rate-limiting/go.mod | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/rate-limiting/go.mod b/examples/rate-limiting/go.mod index 5ec49ddc..f3cf7aa1 100644 --- a/examples/rate-limiting/go.mod +++ b/examples/rate-limiting/go.mod @@ -1,6 +1,8 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting -go 1.25 +go 1.23.0 + +toolchain go1.24.4 require ( github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 From c6b229c855e5dec859ff598c7dc01a4b9b05106b Mon Sep 17 00:00:00 2001 From: Martin Emde Date: Fri, 4 Jul 2025 10:46:41 -0700 Subject: [PATCH 003/221] mcp: render blank text field in TextContent instead of omitempty (#91) The text field is required, but current implementation omits the text field when it is a blank string. Some clients fail to parse the response if the text field is omitted. A blank text result is a valid MCP response. Claude desktop error message: ``` ClaudeAiToolResultRequest.content.0.text.text: Field required ``` The solution is to return the field even when it is blank by removing `omitempty`. --- mcp/content.go | 14 ++++++++++++-- mcp/content_test.go | 8 ++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/mcp/content.go b/mcp/content.go index ed7f6f99..61111a99 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -2,6 +2,9 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// TODO(findleyr): update JSON marshalling of all content types to preserve required fields. +// (See [TextContent.MarshalJSON], which handles this for text content). + package mcp import ( @@ -25,12 +28,19 @@ type TextContent struct { } func (c *TextContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure the required "text" field is always included, even when empty. + wire := struct { + Type string `json:"type"` + Text string `json:"text"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ Type: "text", Text: c.Text, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *TextContent) fromWire(wire *wireContent) { diff --git a/mcp/content_test.go b/mcp/content_test.go index 5ee6f66c..8c4e5306 100644 --- a/mcp/content_test.go +++ b/mcp/content_test.go @@ -22,6 +22,14 @@ func TestContent(t *testing.T) { &mcp.TextContent{Text: "hello"}, `{"type":"text","text":"hello"}`, }, + { + &mcp.TextContent{Text: ""}, + `{"type":"text","text":""}`, + }, + { + &mcp.TextContent{}, + `{"type":"text","text":""}`, + }, { &mcp.TextContent{ Text: "hello", From 328a25d503562bfc89037e98e08fddfc0073dfc2 Mon Sep 17 00:00:00 2001 From: Martin Emde Date: Fri, 4 Jul 2025 11:35:33 -0700 Subject: [PATCH 004/221] mcp: ensure meta is not skipped on blob ResourceContents (#96) I noticed this while working on #95. Seems Meta was accidentally skipped in the custom wire format struct. --- mcp/content.go | 2 ++ mcp/content_test.go | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/mcp/content.go b/mcp/content.go index 61111a99..fd027cf8 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -187,10 +187,12 @@ func (r ResourceContents) MarshalJSON() ([]byte, error) { URI string `json:"uri,omitempty"` MIMEType string `json:"mimeType,omitempty"` Blob []byte `json:"blob"` + Meta Meta `json:"_meta,omitempty"` }{ URI: r.URI, MIMEType: r.MIMEType, Blob: r.Blob, + Meta: r.Meta, } return json.Marshal(br) } diff --git a/mcp/content_test.go b/mcp/content_test.go index 8c4e5306..7a549bea 100644 --- a/mcp/content_test.go +++ b/mcp/content_test.go @@ -154,6 +154,10 @@ func TestEmbeddedResource(t *testing.T) { &mcp.ResourceContents{URI: "u", Blob: []byte{1}}, `{"uri":"u","blob":"AQ=="}`, }, + { + &mcp.ResourceContents{URI: "u", MIMEType: "m", Blob: []byte{1}, Meta: mcp.Meta{"key": "value"}}, + `{"uri":"u","mimeType":"m","blob":"AQ==","_meta":{"key":"value"}}`, + }, } { data, err := json.Marshal(tt.rc) if err != nil { From aebd2449813d66cf742438ea37bdd8662fc10c30 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 7 Jul 2025 14:54:52 -0400 Subject: [PATCH 005/221] mcp: new API for features (#88) BREAKING CHANGE Change the API for adding tools, prompts, resources and resource templates to a server. The ServerXXX types are gone along with the plural Add methods (AddTools, AddPrompts etc.) Instead there are single Add methods (AddTool, AddPrompt). In addition, instead of NewServerTool, there is AddTool[In, Out]. Fixes #73. --- README.md | 24 +++- design/design.md | 139 +++++++------------ examples/hello/main.go | 24 ++-- examples/sse/main.go | 6 +- internal/readme/README.src.md | 15 +++ internal/readme/server/server.go | 9 +- mcp/client_list_test.go | 58 ++++---- mcp/cmd_test.go | 3 +- mcp/features_test.go | 12 +- mcp/mcp_test.go | 135 +++++++------------ mcp/prompt.go | 7 +- mcp/resource.go | 20 +-- mcp/server.go | 170 ++++++++++-------------- mcp/server_example_test.go | 2 +- mcp/shared_test.go | 11 +- mcp/sse_example_test.go | 2 +- mcp/sse_test.go | 2 +- mcp/streamable_test.go | 6 +- mcp/tool.go | 221 ++++++------------------------- mcp/tool_test.go | 112 ++++++++-------- 20 files changed, 366 insertions(+), 612 deletions(-) diff --git a/README.md b/README.md index d4900674..ce4d9f11 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,21 @@ # MCP Go SDK +***BREAKING CHANGES*** + +The latest version contains breaking changes: + +- Server.AddTools is replaced by Server.AddTool. + +- NewServerTool is replaced by AddTool. AddTool takes a Tool rather than a name and description, so you can + set any field on the Tool that you want before associating it with a handler. + +- Tool options have been removed. If you don't want AddTool to infer a JSON Schema for you, you can construct one + as a struct literal, or using any other code that suits you. + +- AddPrompts, AddResources and AddResourceTemplates are similarly replaced by singular methods which pair the + feature with a handler. The ServerXXX types have been removed. + [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) This repository contains an unreleased implementation of the official Go @@ -99,7 +114,7 @@ import ( ) type HiParams struct { - Name string `json:"name"` + Name string `json:"name", mcp:"the name of the person to greet"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { @@ -111,11 +126,8 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam func main() { // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools( - mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name of the person to greet")), - )), - ) + + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { log.Fatal(err) diff --git a/design/design.md b/design/design.md index b52e9c10..610de399 100644 --- a/design/design.md +++ b/design/design.md @@ -372,12 +372,11 @@ A server that can handle that client call would look like this: ```go // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) -server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) +mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects. -transport := mcp.NewStdioTransport() -session, err := server.Connect(ctx, transport) -... -return session.Wait() +if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { + log.Fatal(err) +} ``` For convenience, we provide `Server.Run` to handle the common case of running a session until the client disconnects: @@ -603,14 +602,14 @@ type ClientOptions struct { ### Tools -A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` is a tool bound to a tool handler. +A `Tool` is a logical MCP tool, generated from the MCP spec. -A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to make `CallToolParams` generic, with a type parameter `TArgs` for the tool argument type. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. The bound `ServerTool` type expects a `json.RawMessage` for its tool arguments, but the `NewServerTool` constructor described below provides a mechanism to bind a typed handler. +A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to have a generic version of `CallToolParams`, with a type parameter `In` for the tool argument type, as well as a generic version of for `CallToolResult`. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. ```go -type CallToolParams[TArgs any] struct { +type CallToolParamsFor[In any] struct { Meta Meta `json:"_meta,omitempty"` - Arguments TArgs `json:"arguments,omitempty"` + Arguments In `json:"arguments,omitempty"` Name string `json:"name"` } @@ -621,23 +620,31 @@ type Tool struct { Name string `json:"name"` } -type ToolHandler[TArgs] func(context.Context, *ServerSession, *CallToolParams[TArgs]) (*CallToolResult, error) +type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) +type ToolHandler = ToolHandlerFor[map[string]any, any] +``` -type ServerTool struct { - Tool Tool - Handler ToolHandler[json.RawMessage] -} +Add tools to a server with the `AddTool` method or function. The function is generic and infers schemas from the handler +arguments: + +```go +func (s *Server) AddTool(t *Tool, h ToolHandler) +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) ``` -Add tools to a server with `AddTools`: +```go +mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add numbers"}, addHandler) +mcp.AddTool(server, &mcp.Tool{Name: "subtract", Description: "subtract numbers"}, subHandler) +``` +The `AddTool` method requires an input schema, and optionally an output one. It will not modify them. +The handler should accept a `CallToolParams` and return a `CallToolResult`. ```go -server.AddTools( - mcp.NewServerTool("add", "add numbers", addHandler), - mcp.NewServerTool("subtract, subtract numbers", subHandler)) +t := &Tool{Name: ..., Description: ..., InputSchema: &jsonschema.Schema{...}} +server.AddTool(t, myHandler) ``` -Remove them by name with `RemoveTools`: +Tools can be removed by name with `RemoveTools`: ```go server.RemoveTools("add", "subtract") @@ -650,53 +657,30 @@ A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), pr Both of these have their advantages and disadvantages. Reflection is nice, because it allows you to bind directly to a Go API, and means that the JSON schema of your API is compatible with your Go types by construction. It also means that concerns like parsing and validation can be handled automatically. However, it can become cumbersome to express the full breadth of JSON schema using Go types or struct tags, and sometimes you want to express things that aren’t naturally modeled by Go types, like unions. Explicit schemas are simple and readable, and give the caller full control over their tool definition, but involve significant boilerplate. -We have found that a hybrid model works well, where the _initial_ schema is derived using reflection, but any customization on top of that schema is applied using variadic options. We achieve this using a `NewServerTool` helper, which generates the schema from the input type, and wraps the handler to provide parsing and validation. The schema (and potentially other features) can be customized using ToolOptions. - -```go -// NewServerTool creates a Tool using reflection on the given handler. -func NewServerTool[TArgs any](name, description string, handler ToolHandler[TArgs], opts …ToolOption) *ServerTool +We provide both ways. The `jsonschema.For[T]` function will infer a schema, and it is called by the `AddTool` generic function. +Users can also call it themselves, or build a schema directly as a struct literal. They can still use the `AddTool` function to +create a typed handler, since `AddTool` doesn't touch schemas that are already present. -type ToolOption interface { /* ... */ } -``` -`NewServerTool` determines the input schema for a Tool from the `TArgs` type. Each struct field that would be marshaled by `encoding/json.Marshal` becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: +If the tool's `InputSchema` is nil, it is inferred from the `In` type parameter. If the `OutputSchema` is nil, it is inferred from the `Out` type parameter (unless `Out` is `any`). +For example, given this handler: ```go -struct { - Name string `json:"name"` - Count int `json:"count,omitempty"` - Choices []string - Password []byte `json:"-"` +type AddParams struct { + X int `json:"x"` + Y int `json:"y"` } -``` - -"name" and "Choices" are required, while "count" is optional. - -As of this writing, the only `ToolOption` is `Input`, which allows customizing the input schema of the tool using schema options. These schema options are recursive, in the sense that they may also be applied to properties. - -```go -func Input(...SchemaOption) ToolOption - -type Property(name string, opts ...SchemaOption) SchemaOption -type Description(desc string) SchemaOption -// etc. -``` - -For example: -```go -NewServerTool(name, description, handler, - Input(Property("count", Description("size of the inventory")))) +func addHandler(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddParams]) (*mcp.CallToolResultFor[int], error) { + return &mcp.CallToolResultFor[int]{StructuredContent: params.Arguments.X + params.Arguments.Y}, nil +} ``` -The most recent JSON Schema spec defines over 40 keywords. Providing them all as options would bloat the API despite the fact that most would be very rarely used. For less common keywords, use the `Schema` option to set the schema explicitly: - +You can add it to a server like this: ```go -NewServerTool(name, description, handler, - Input(Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true})))) +mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add numbers"}, addHandler) ``` - -Schemas are validated on the server before the tool handler is called. +The input schema will be inferred from `AddParams`, and the output schema from `int`. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. @@ -718,15 +702,7 @@ For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, `AddTool ### Prompts -Use `NewServerPrompt` to create a prompt. As with tools, prompt argument schemas can be inferred from a struct, or obtained from options. - -```go -func NewServerPrompt[TReq any](name, description string, - handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), - opts ...PromptOption) *ServerPrompt -``` - -Use `AddPrompts` to add prompts to the server, and `RemovePrompts` +Use `AddPrompt` to add a prompt to the server, and `RemovePrompts` to remove them by name. ```go @@ -734,11 +710,12 @@ type codeReviewArgs struct { Code string `json:"code"` } -func codeReviewHandler(context.Context, *ServerSession, codeReviewArgs) {...} +func codeReviewHandler(context.Context, *ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {...} -server.AddPrompts( - NewServerPrompt("code_review", "review code", codeReviewHandler, - Argument("code", Description("the code to review")))) +server.AddPrompt( + &mcp.Prompt{Name: "code_review", Description: "review code"}, + codeReviewHandler, +) server.RemovePrompts("code_review") ``` @@ -757,25 +734,11 @@ type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) The arguments include the `ServerSession` so the handler can observe the client's roots. The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the resource. -The `ServerResource` and `ServerResourceTemplate` types hold the association between the resource and its handler: - -```go -type ServerResource struct { - Resource Resource - Handler ResourceHandler -} - -type ServerResourceTemplate struct { - Template ResourceTemplate - Handler ResourceHandler -} -``` - -To add a resource or resource template to a server, users call the `AddResources` and `AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s. We also provide methods to remove them. +To add a resource or resource template to a server, users call the `AddResource` and `AddResourceTemplate` methods. We also provide methods to remove them. ```go -func (*Server) AddResources(...*ServerResource) -func (*Server) AddResourceTemplates(...*ServerResourceTemplate) +func (*Server) AddResource(*Resource, ResourceHandler) +func (*Server) AddResourceTemplate(*ResourceTemplate, ResourceHandler) func (s *Server) RemoveResources(uris ...string) func (s *Server) RemoveResourceTemplates(uriTemplates ...string) @@ -796,9 +759,7 @@ Here is an example: ```go // Safely read "/public/puppies.txt". -s.AddResources(&mcp.ServerResource{ - Resource: mcp.Resource{URI: "file:///puppies.txt"}, - Handler: s.FileReadResourceHandler("/public")}) +s.AddResource(&mcp.Resource{URI: "file:///puppies.txt"}, s.FileReadResourceHandler("/public")) ``` Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, and the corresponding iterator methods `Resources` and `ResourceTemplates`. diff --git a/examples/hello/main.go b/examples/hello/main.go index 9af34cc3..4db20cc8 100644 --- a/examples/hello/main.go +++ b/examples/hello/main.go @@ -19,7 +19,7 @@ import ( var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") type HiArgs struct { - Name string `json:"name"` + Name string `json:"name" mcp:"the name to say hi to"` } func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { @@ -43,21 +43,13 @@ func main() { flag.Parse() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name to say hi to")), - ))) - server.AddPrompts(&mcp.ServerPrompt{ - Prompt: &mcp.Prompt{Name: "greet"}, - Handler: PromptHi, - }) - server.AddResources(&mcp.ServerResource{ - Resource: &mcp.Resource{ - Name: "info", - MIMEType: "text/plain", - URI: "embedded:info", - }, - Handler: handleEmbeddedResource, - }) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + server.AddPrompt(&mcp.Prompt{Name: "greet"}, PromptHi) + server.AddResource(&mcp.Resource{ + Name: "info", + MIMEType: "text/plain", + URI: "embedded:info", + }, handleEmbeddedResource) if *httpAddr != "" { handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { diff --git a/examples/sse/main.go b/examples/sse/main.go index 97ea1bd0..c93320ab 100644 --- a/examples/sse/main.go +++ b/examples/sse/main.go @@ -35,12 +35,12 @@ func main() { } server1 := mcp.NewServer("greeter1", "v0.0.1", nil) - server1.AddTools(mcp.NewServerTool("greet1", "say hi", SayHi)) + mcp.AddTool(server1, &mcp.Tool{Name: "greet1", Description: "say hi"}, SayHi) server2 := mcp.NewServer("greeter2", "v0.0.1", nil) - server2.AddTools(mcp.NewServerTool("greet2", "say hello", SayHi)) + mcp.AddTool(server2, &mcp.Tool{Name: "greet2", Description: "say hello"}, SayHi) - log.Printf("MCP servers serving at %s\n", *httpAddr) + log.Printf("MCP servers serving at %s", *httpAddr) handler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { url := request.URL.Path log.Printf("Handling request for URL %s\n", url) diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 629629a4..11d63110 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,5 +1,20 @@ # MCP Go SDK +***BREAKING CHANGES*** + +The latest version contains breaking changes: + +- Server.AddTools is replaced by Server.AddTool. + +- NewServerTool is replaced by AddTool. AddTool takes a Tool rather than a name and description, so you can + set any field on the Tool that you want before associating it with a handler. + +- Tool options have been removed. If you don't want AddTool to infer a JSON Schema for you, you can construct one + as a struct literal, or using any other code that suits you. + +- AddPrompts, AddResources and AddResourceTemplates are similarly replaced by singular methods which pair the + feature with a handler. The ServerXXX types have been removed. + [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) This repository contains an unreleased implementation of the official Go diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 534e0798..1fe211ea 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -13,7 +13,7 @@ import ( ) type HiParams struct { - Name string `json:"name"` + Name string `json:"name", mcp:"the name of the person to greet"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { @@ -25,11 +25,8 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam func main() { // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools( - mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name of the person to greet")), - )), - ) + + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { log.Fatal(err) diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 7e6da95a..497a9cd0 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -22,12 +22,12 @@ func TestList(t *testing.T) { defer serverSession.Close() t.Run("tools", func(t *testing.T) { - toolA := mcp.NewServerTool("apple", "apple tool", SayHi) - toolB := mcp.NewServerTool("banana", "banana tool", SayHi) - toolC := mcp.NewServerTool("cherry", "cherry tool", SayHi) - tools := []*mcp.ServerTool{toolA, toolB, toolC} - wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} - server.AddTools(tools...) + var wantTools []*mcp.Tool + for _, name := range []string{"apple", "banana", "cherry"} { + t := &mcp.Tool{Name: name, Description: name + " tool"} + wantTools = append(wantTools, t) + mcp.AddTool(server, t, SayHi) + } t.Run("list", func(t *testing.T) { res, err := clientSession.ListTools(ctx, nil) if err != nil { @@ -43,12 +43,13 @@ func TestList(t *testing.T) { }) t.Run("resources", func(t *testing.T) { - resourceA := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://apple"}} - resourceB := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://banana"}} - resourceC := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://cherry"}} - wantResources := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource} - resources := []*mcp.ServerResource{resourceA, resourceB, resourceC} - server.AddResources(resources...) + var wantResources []*mcp.Resource + for _, name := range []string{"apple", "banana", "cherry"} { + r := &mcp.Resource{URI: "http://" + name} + wantResources = append(wantResources, r) + server.AddResource(r, nil) + } + t.Run("list", func(t *testing.T) { res, err := clientSession.ListResources(ctx, nil) if err != nil { @@ -64,15 +65,12 @@ func TestList(t *testing.T) { }) t.Run("templates", func(t *testing.T) { - resourceTmplA := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://apple/{x}"}} - resourceTmplB := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://banana/{x}"}} - resourceTmplC := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://cherry/{x}"}} - wantResourceTemplates := []*mcp.ResourceTemplate{ - resourceTmplA.ResourceTemplate, resourceTmplB.ResourceTemplate, - resourceTmplC.ResourceTemplate, + var wantResourceTemplates []*mcp.ResourceTemplate + for _, name := range []string{"apple", "banana", "cherry"} { + rt := &mcp.ResourceTemplate{URITemplate: "http://" + name + "/{x}"} + wantResourceTemplates = append(wantResourceTemplates, rt) + server.AddResourceTemplate(rt, nil) } - resourceTemplates := []*mcp.ServerResourceTemplate{resourceTmplA, resourceTmplB, resourceTmplC} - server.AddResourceTemplates(resourceTemplates...) t.Run("list", func(t *testing.T) { res, err := clientSession.ListResourceTemplates(ctx, nil) if err != nil { @@ -88,12 +86,12 @@ func TestList(t *testing.T) { }) t.Run("prompts", func(t *testing.T) { - promptA := newServerPrompt("apple", "apple prompt") - promptB := newServerPrompt("banana", "banana prompt") - promptC := newServerPrompt("cherry", "cherry prompt") - wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} - prompts := []*mcp.ServerPrompt{promptA, promptB, promptC} - server.AddPrompts(prompts...) + var wantPrompts []*mcp.Prompt + for _, name := range []string{"apple", "banana", "cherry"} { + p := &mcp.Prompt{Name: name, Description: name + " prompt"} + wantPrompts = append(wantPrompts, p) + server.AddPrompt(p, testPromptHandler) + } t.Run("list", func(t *testing.T) { res, err := clientSession.ListPrompts(ctx, nil) if err != nil { @@ -123,14 +121,6 @@ func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, er } } -// testPromptHandler is used for type inference newServerPrompt. func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { panic("not implemented") } - -func newServerPrompt(name, desc string) *mcp.ServerPrompt { - return &mcp.ServerPrompt{ - Prompt: &mcp.Prompt{Name: name, Description: desc}, - Handler: testPromptHandler, - } -} diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index f66423d6..496694a5 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -31,8 +31,7 @@ func runServer() { ctx := context.Background() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) - + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { log.Fatal(err) } diff --git a/mcp/features_test.go b/mcp/features_test.go index 5ffbce8c..e0165ecb 100644 --- a/mcp/features_test.go +++ b/mcp/features_test.go @@ -27,9 +27,9 @@ func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[Say } func TestFeatureSetOrder(t *testing.T) { - toolA := NewServerTool("apple", "apple tool", SayHi).Tool - toolB := NewServerTool("banana", "banana tool", SayHi).Tool - toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool + toolA := &Tool{Name: "apple", Description: "apple tool"} + toolB := &Tool{Name: "banana", Description: "banana tool"} + toolC := &Tool{Name: "cherry", Description: "cherry tool"} testCases := []struct { tools []*Tool @@ -52,9 +52,9 @@ func TestFeatureSetOrder(t *testing.T) { } func TestFeatureSetAbove(t *testing.T) { - toolA := NewServerTool("apple", "apple tool", SayHi).Tool - toolB := NewServerTool("banana", "banana tool", SayHi).Tool - toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool + toolA := &Tool{Name: "apple", Description: "apple tool"} + toolB := &Tool{Name: "banana", Description: "banana tool"} + toolC := &Tool{Name: "cherry", Description: "cherry tool"} testCases := []struct { tools []*Tool diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5f42b1b9..70c79b58 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -21,7 +21,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonschema" ) @@ -30,6 +29,9 @@ type hiParams struct { Name string } +// TODO(jba): after schemas are stateless (WIP), this can be a variable. +func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } + func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { if err := ss.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) @@ -63,9 +65,31 @@ func TestEndToEnd(t *testing.T) { }, } s := NewServer("testServer", "v1.0.0", sopts) - add(tools, s.AddTools, "greet", "fail") - add(prompts, s.AddPrompts, "code_review", "fail") - add(resources, s.AddResources, "info.txt", "fail.txt") + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, + func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + return nil, errTestFailure + }) + s.AddPrompt(&Prompt{ + Name: "code_review", + Description: "do a code review", + Arguments: []*PromptArgument{{Name: "Code", Required: true}}, + }, func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{ + {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, + }, + }, nil + }) + s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { + return nil, errTestFailure + }) + s.AddResource(resource1, readHandler) + s.AddResource(resource2, readHandler) // Connect the server. ss, err := s.Connect(ctx, st) @@ -154,39 +178,14 @@ func TestEndToEnd(t *testing.T) { t.Errorf("fail returned unexpected error: got %v, want containing %v", err, errTestFailure) } - s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) + s.AddPrompt(&Prompt{Name: "T"}, nil) waitForNotification(t, "prompts") s.RemovePrompts("T") waitForNotification(t, "prompts") }) t.Run("tools", func(t *testing.T) { - res, err := cs.ListTools(ctx, nil) - if err != nil { - t.Errorf("tools/list failed: %v", err) - } - wantTools := []*Tool{ - { - Name: "fail", - InputSchema: nil, - }, - { - Name: "greet", - Description: "say hi", - InputSchema: &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string"}, - }, - AdditionalProperties: falseSchema(), - }, - }, - } - if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) - } - + // ListTools is tested in client_list_test.go. gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", Arguments: map[string]any{"name": "user"}, @@ -222,7 +221,7 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } - s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}, Handler: nopHandler}) + s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{}}, nopHandler) waitForNotification(t, "tools") s.RemoveTools("T") waitForNotification(t, "tools") @@ -246,8 +245,7 @@ func TestEndToEnd(t *testing.T) { MIMEType: "text/template", URITemplate: "file:///{+filename}", // the '+' means that filename can contain '/' } - st := &ServerResourceTemplate{ResourceTemplate: template, Handler: readHandler} - s.AddResourceTemplates(st) + s.AddResourceTemplate(template, readHandler) tres, err := cs.ListResourceTemplates(ctx, nil) if err != nil { t.Fatal(err) @@ -292,7 +290,7 @@ func TestEndToEnd(t *testing.T) { } } - s.AddResources(&ServerResource{Resource: &Resource{URI: "http://U"}}) + s.AddResource(&Resource{URI: "http://U"}, nil) waitForNotification(t, "resources") s.RemoveResources("http://U") waitForNotification(t, "resources") @@ -434,40 +432,6 @@ func TestEndToEnd(t *testing.T) { var ( errTestFailure = errors.New("mcp failure") - tools = map[string]*ServerTool{ - "greet": NewServerTool("greet", "say hi", sayHi), - "fail": { - Tool: &Tool{Name: "fail"}, - Handler: func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { - return nil, errTestFailure - }, - }, - } - - prompts = map[string]*ServerPrompt{ - "code_review": { - Prompt: &Prompt{ - Name: "code_review", - Description: "do a code review", - Arguments: []*PromptArgument{{Name: "Code", Required: true}}, - }, - Handler: func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { - return &GetPromptResult{ - Description: "Code review prompt", - Messages: []*PromptMessage{ - {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, - }, - }, nil - }, - }, - "fail": { - Prompt: &Prompt{Name: "fail"}, - Handler: func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { - return nil, errTestFailure - }, - }, - } - resource1 = &Resource{ Name: "public", MIMEType: "text/plain", @@ -484,11 +448,6 @@ var ( URI: "embedded:info", } readHandler = fileResourceHandler("testdata/files") - resources = map[string]*ServerResource{ - "info.txt": {resource1, readHandler}, - "fail.txt": {resource2, readHandler}, - "info": {resource3, handleEmbeddedResource}, - } ) var embeddedResources = map[string]string{ @@ -540,21 +499,21 @@ func errorCode(err error) int64 { return -1 } -// basicConnection returns a new basic client-server connection configured with -// the provided tools. +// basicConnection returns a new basic client-server connection, with the server +// configured via the provided function. // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) { +func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) { t.Helper() ctx := context.Background() ct, st := NewInMemoryTransports() s := NewServer("testServer", "v1.0.0", nil) - - // The 'greet' tool says hi. - s.AddTools(tools...) + if config != nil { + config(s) + } ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) @@ -569,7 +528,9 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Clien } func TestServerClosing(t *testing.T) { - cc, cs := basicConnection(t, NewServerTool("greet", "say hi", sayHi)) + cc, cs := basicConnection(t, func(s *Server) { + AddTool(s, greetTool(), sayHi) + }) defer cs.Close() ctx := context.Background() @@ -651,11 +612,9 @@ func TestCancellation(t *testing.T) { } return nil, nil } - st := &ServerTool{ - Tool: &Tool{Name: "slow"}, - Handler: slowRequest, - } - _, cs := basicConnection(t, st) + _, cs := basicConnection(t, func(s *Server) { + s.AddTool(&Tool{Name: "slow"}, slowRequest) + }) defer cs.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -852,7 +811,7 @@ func TestKeepAlive(t *testing.T) { KeepAlive: 100 * time.Millisecond, } s := NewServer("testServer", "v1.0.0", serverOpts) - s.AddTools(NewServerTool("greet", "say hi", sayHi)) + AddTool(s, greetTool(), sayHi) ss, err := s.Connect(ctx, st) if err != nil { @@ -897,7 +856,7 @@ func TestKeepAliveFailure(t *testing.T) { // Server without keepalive (to test one-sided keepalive) s := NewServer("testServer", "v1.0.0", nil) - s.AddTools(NewServerTool("greet", "say hi", sayHi)) + AddTool(s, greetTool(), sayHi) ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) diff --git a/mcp/prompt.go b/mcp/prompt.go index e2db7b27..0ecf5528 100644 --- a/mcp/prompt.go +++ b/mcp/prompt.go @@ -11,8 +11,7 @@ import ( // A PromptHandler handles a call to prompts/get. type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error) -// A Prompt is a prompt definition bound to a prompt handler. -type ServerPrompt struct { - Prompt *Prompt - Handler PromptHandler +type serverPrompt struct { + prompt *Prompt + handler PromptHandler } diff --git a/mcp/resource.go b/mcp/resource.go index 18e0bec4..4202fdac 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -20,16 +20,16 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" ) -// A ServerResource associates a Resource with its handler. -type ServerResource struct { - Resource *Resource - Handler ResourceHandler +// A serverResource associates a Resource with its handler. +type serverResource struct { + resource *Resource + handler ResourceHandler } -// A ServerResourceTemplate associates a ResourceTemplate with its handler. -type ServerResourceTemplate struct { - ResourceTemplate *ResourceTemplate - Handler ResourceHandler +// A serverResourceTemplate associates a ResourceTemplate with its handler. +type serverResourceTemplate struct { + resourceTemplate *ResourceTemplate + handler ResourceHandler } // A ResourceHandler is a function that reads a resource. @@ -156,8 +156,8 @@ func fileRoot(root *Root) (_ string, err error) { // Matches reports whether the receiver's uri template matches the uri. // TODO: use "github.com/yosida95/uritemplate/v3" -func (sr *ServerResourceTemplate) Matches(uri string) bool { - re, err := uriTemplateToRegexp(sr.ResourceTemplate.URITemplate) +func (sr *serverResourceTemplate) Matches(uri string) bool { + re, err := uriTemplateToRegexp(sr.resourceTemplate.URITemplate) if err != nil { return false } diff --git a/mcp/server.go b/mcp/server.go index 69666a6a..cd8f808b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "iter" + "log" "net/url" "path/filepath" "slices" @@ -20,7 +21,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" - "github.com/modelcontextprotocol/go-sdk/jsonschema" ) const DefaultPageSize = 1000 @@ -36,10 +36,10 @@ type Server struct { opts ServerOptions mu sync.Mutex - prompts *featureSet[*ServerPrompt] - tools *featureSet[*ServerTool] - resources *featureSet[*ServerResource] - resourceTemplates *featureSet[*ServerResourceTemplate] + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] sessions []*ServerSession sendingMethodHandler_ MethodHandler[*ServerSession] receivingMethodHandler_ MethodHandler[*ServerSession] @@ -87,28 +87,23 @@ func NewServer(name, version string, opts *ServerOptions) *Server { name: name, version: version, opts: *opts, - prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }), - tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), - resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), - resourceTemplates: newFeatureSet(func(t *ServerResourceTemplate) string { return t.ResourceTemplate.URITemplate }), + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], } } -// AddPrompts adds the given prompts to the server, -// replacing any with the same names. -func (s *Server) AddPrompts(prompts ...*ServerPrompt) { - // Only notify if something could change. - if len(prompts) == 0 { - return - } - // Assume there was a change, since add replaces existing roots. - // (It's possible a root was replaced with an identical one, but not worth checking.) +// AddPrompt adds a [Prompt] to the server, or replaces one with the same name. +func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { + // Assume there was a change, since add replaces existing items. + // (It's possible an item was replaced with an identical one, but not worth checking.) s.changeAndNotify( notificationPromptListChanged, &PromptListChangedParams{}, - func() bool { s.prompts.add(prompts...); return true }) + func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) } // RemovePrompts removes the prompts with the given names. @@ -118,55 +113,44 @@ func (s *Server) RemovePrompts(names ...string) { func() bool { return s.prompts.remove(names...) }) } -// AddTools adds the given tools to the server, -// replacing any with the same names. -// The arguments must not be modified after this call. -// -// AddTools panics if errors are detected. -func (s *Server) AddTools(tools ...*ServerTool) { - if err := s.addToolsErr(tools...); err != nil { +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// The tool's input schema must be non-nil. +// The Tool argument must not be modified after this call. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + // TODO(jba): This is a breaking behavior change. Add before v0.2.0? + if t.InputSchema == nil { + log.Printf("mcp: tool %q has a nil input schema. This will panic in a future release.", t.Name) + // panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name)) + } + if err := addToolErr(s, t, h); err != nil { panic(err) } } -// addToolsErr is like [AddTools], but returns an error instead of panicking. -func (s *Server) addToolsErr(tools ...*ServerTool) error { - // Only notify if something could change. - if len(tools) == 0 { - return nil +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// If the tool's input schema is nil, it is set to the schema inferred from the In +// type parameter, using [jsonschema.For]. +// If the tool's output schema is nil and the Out type parameter is not the empty +// interface, then the output schema is set to the schema inferred from Out. +// The Tool argument must not be modified after this call. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + if err := addToolErr(s, t, h); err != nil { + panic(err) } - // Wrap the user's Handlers with rawHandlers that take a json.RawMessage. - for _, st := range tools { - if st.rawHandler == nil { - // This ServerTool was not created with NewServerTool. - if st.Handler == nil { - return fmt.Errorf("AddTools: tool %q has no handler", st.Tool.Name) - } - st.rawHandler = newRawHandler(st) - // Resolve the schemas, with no base URI. We don't expect tool schemas to - // refer outside of themselves. - if st.Tool.InputSchema != nil { - r, err := st.Tool.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return err - } - st.inputResolved = r - } +} - // if st.Tool.OutputSchema != nil { - // st.outputResolved, err := st.Tool.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - // if err != nil { - // return err - // } - // } - } +func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) { + defer util.Wrapf(&err, "adding tool %q", t.Name) + st, err := newServerTool(t, h) + if err != nil { + return err } - // Assume there was a change, since add replaces existing tools. // (It's possible a tool was replaced with an identical one, but not worth checking.) - // TODO: surface notify error here? + // TODO: Batch these changes by size and time? The typescript SDK doesn't. + // TODO: Surface notify error here? best not, in case we need to batch. s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { s.tools.add(tools...); return true }) + func() bool { s.tools.add(st); return true }) return nil } @@ -177,26 +161,19 @@ func (s *Server) RemoveTools(names ...string) { func() bool { return s.tools.remove(names...) }) } -// AddResources adds the given resources to the server. -// If a resource with the same URI already exists, it is replaced. -// AddResources panics if a resource URI is invalid or not absolute (has an empty scheme). -func (s *Server) AddResources(resources ...*ServerResource) { - // Only notify if something could change. - if len(resources) == 0 { - return - } +// AddResource adds a [Resource] to the server, or replaces one with the same URI. +// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResource(r *Resource, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - for _, r := range resources { - u, err := url.Parse(r.Resource.URI) - if err != nil { - panic(err) // url.Parse includes the URI in the error - } - if !u.IsAbs() { - panic(fmt.Errorf("URI %s needs a scheme", r.Resource.URI)) - } - s.resources.add(r) + u, err := url.Parse(r.URI) + if err != nil { + panic(err) // url.Parse includes the URI in the error + } + if !u.IsAbs() { + panic(fmt.Errorf("URI %s needs a scheme", r.URI)) } + s.resources.add(&serverResource{r, h}) return true }) } @@ -208,20 +185,13 @@ func (s *Server) RemoveResources(uris ...string) { func() bool { return s.resources.remove(uris...) }) } -// AddResourceTemplates adds the given resource templates to the server. -// If a resource template with the same URI template already exists, it will be replaced. -// AddResourceTemplates panics if a URI template is invalid or not absolute (has an empty scheme). -func (s *Server) AddResourceTemplates(templates ...*ServerResourceTemplate) { - // Only notify if something could change. - if len(templates) == 0 { - return - } +// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces on with the same URI. +// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). +func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - for _, t := range templates { - // TODO: check template validity. - s.resourceTemplates.add(t) - } + // TODO: check template validity. + s.resourceTemplates.add(&serverResourceTemplate{t, h}) return true }) } @@ -268,10 +238,10 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr if params == nil { params = &ListPromptsParams{} } - return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*ServerPrompt) { + return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { res.Prompts = []*Prompt{} // avoid JSON null for _, p := range prompts { - res.Prompts = append(res.Prompts, p.Prompt) + res.Prompts = append(res.Prompts, p.prompt) } }) } @@ -284,7 +254,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr // TODO: surface the error code over the wire, instead of flattening it into the string. return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) } - return prompt.Handler(ctx, cc, params) + return prompt.handler(ctx, cc, params) } func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { @@ -293,22 +263,22 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool if params == nil { params = &ListToolsParams{} } - return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*ServerTool) { + return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { res.Tools = []*Tool{} // avoid JSON null for _, t := range tools { - res.Tools = append(res.Tools, t.Tool) + res.Tools = append(res.Tools, t.tool) } }) } func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { s.mu.Lock() - tool, ok := s.tools.get(params.Name) + st, ok := s.tools.get(params.Name) s.mu.Unlock() if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name) } - return tool.rawHandler(ctx, cc, params) + return st.handler(ctx, cc, params) } func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { @@ -317,10 +287,10 @@ func (s *Server) listResources(_ context.Context, _ *ServerSession, params *List if params == nil { params = &ListResourcesParams{} } - return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*ServerResource) { + return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { res.Resources = []*Resource{} // avoid JSON null for _, r := range resources { - res.Resources = append(res.Resources, r.Resource) + res.Resources = append(res.Resources, r.resource) } }) } @@ -332,10 +302,10 @@ func (s *Server) listResourceTemplates(_ context.Context, _ *ServerSession, para params = &ListResourceTemplatesParams{} } return paginateList(s.resourceTemplates, s.opts.PageSize, params, &ListResourceTemplatesResult{}, - func(res *ListResourceTemplatesResult, rts []*ServerResourceTemplate) { + func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null for _, rt := range rts { - res.ResourceTemplates = append(res.ResourceTemplates, rt.ResourceTemplate) + res.ResourceTemplates = append(res.ResourceTemplates, rt.resourceTemplate) } }) } @@ -376,12 +346,12 @@ func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, boo defer s.mu.Unlock() // Try resources first. if r, ok := s.resources.get(uri); ok { - return r.Handler, r.Resource.MIMEType, true + return r.handler, r.resource.MIMEType, true } // Look for matching template. for rt := range s.resourceTemplates.all() { if rt.Matches(uri) { - return rt.Handler, rt.ResourceTemplate.MIMEType, true + return rt.handler, rt.resourceTemplate.MIMEType, true } } return nil, "", false diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 9e982374..fd6eea00 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -29,7 +29,7 @@ func ExampleServer() { clientTransport, serverTransport := mcp.NewInMemoryTransports() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) serverSession, err := server.Connect(ctx, serverTransport) if err != nil { diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 5a1d5d02..f319d80e 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -12,7 +12,7 @@ import ( ) // TODO(jba): this shouldn't be in this file, but tool_test.go doesn't have access to unexported symbols. -func TestNewServerToolValidate(t *testing.T) { +func TestToolValidate(t *testing.T) { // Check that the tool returned from NewServerTool properly validates its input schema. type req struct { @@ -26,9 +26,10 @@ func TestNewServerToolValidate(t *testing.T) { return nil, nil } - tool := NewServerTool("test", "test", dummyHandler) - // Need to add the tool to a server to get resolved schemas. - // s := NewServer("", "", nil) + st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) + if err != nil { + t.Fatal(err) + } for _, tt := range []struct { desc string @@ -71,7 +72,7 @@ func TestNewServerToolValidate(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = tool.rawHandler(context.Background(), nil, + _, err = st.handler(context.Background(), nil, &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}) if err == nil && tt.want != "" { t.Error("got success, wanted failure") diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 70f84c3e..816e0134 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -28,7 +28,7 @@ func Add(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsF func ExampleSSEHandler() { server := mcp.NewServer("adder", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("add", "add two numbers", Add)) + mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) httpServer := httptest.NewServer(handler) diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 23621931..e1df9536 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -20,7 +20,7 @@ func TestSSEServer(t *testing.T) { t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) { ctx := context.Background() server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(NewServerTool("greet", "say hi", sayHi)) + AddTool(server, &Tool{Name: "greet"}, sayHi) sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a8c916e8..8925b3da 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -32,8 +32,7 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(NewServerTool("greet", "say hi", sayHi)) - + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) @@ -323,13 +322,12 @@ func TestStreamableServerTransport(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. server := NewServer("testServer", "v1.0.0", nil) - tool := NewServerTool("tool", "test tool", func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { + AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { if test.tool != nil { test.tool(t, ctx, ss) } return &CallToolResultFor[any]{}, nil }) - server.AddTools(tool) // Start the streamable handler. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) diff --git a/mcp/tool.go b/mcp/tool.go index a6f228eb..fc154991 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -9,7 +9,7 @@ import ( "context" "encoding/json" "fmt" - "slices" + "reflect" "github.com/modelcontextprotocol/go-sdk/jsonschema" ) @@ -17,8 +17,7 @@ import ( // A ToolHandler handles a call to tools/call. // [CallToolParams.Arguments] will contain a map[string]any that has been validated // against the input schema. -// TODO: Perhaps this should be an alias for ToolHandlerFor[map[string]any, map[string]any]? -type ToolHandler func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) +type ToolHandler = ToolHandlerFor[map[string]any, any] // A ToolHandlerFor handles a call to tools/call with typed arguments and results. type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) @@ -26,62 +25,33 @@ type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallTool // A rawToolHandler is like a ToolHandler, but takes the arguments as as json.RawMessage. type rawToolHandler = func(context.Context, *ServerSession, *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) -// A ServerTool is a tool definition that is bound to a tool handler. -type ServerTool struct { - Tool *Tool - Handler ToolHandler - // Set in NewServerTool or Server.addToolsErr. - rawHandler rawToolHandler - // Resolved tool schemas. Set in Server.addToolsErr. +// A serverTool is a tool definition that is bound to a tool handler. +type serverTool struct { + tool *Tool + handler rawToolHandler + // Resolved tool schemas. Set in newServerTool. inputResolved, outputResolved *jsonschema.Resolved } -// NewServerTool is a helper to make a tool using reflection on the given type parameters. -// When the tool is called, CallToolParams.Arguments will be of type In. -// -// If provided, variadic [ToolOption] values may be used to customize the tool. -// -// The input schema for the tool is extracted from the request type for the -// handler, and used to unmmarshal and validate requests to the handler. This -// schema may be customized using the [Input] option. -// -// TODO(jba): check that structured content is set in response. -func NewServerTool[In, Out any](name, description string, handler ToolHandlerFor[In, Out], opts ...ToolOption) *ServerTool { - st, err := newServerToolErr[In, Out](name, description, handler, opts...) - if err != nil { - panic(fmt.Errorf("NewServerTool(%q): %w", name, err)) - } - return st -} +// newServerTool creates a serverTool from a tool and a handler. +// If the tool doesn't have an input schema, it is inferred from In. +// If the tool doesn't have an output schema and Out != any, it is inferred from Out. +func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool, error) { + st := &serverTool{tool: t} -func newServerToolErr[In, Out any](name, description string, handler ToolHandlerFor[In, Out], opts ...ToolOption) (*ServerTool, error) { - // TODO: check that In is a struct. - ischema, err := jsonschema.For[In]() - if err != nil { + if err := setSchema[In](&t.InputSchema, &st.inputResolved); err != nil { return nil, err } - // TODO: uncomment when output schemas drop. - // oschema, err := jsonschema.For[TRes]() - // if err != nil { - // return nil, err - // } - - t := &ServerTool{ - Tool: &Tool{ - Name: name, - Description: description, - InputSchema: ischema, - // OutputSchema: oschema, - }, - } - for _, opt := range opts { - opt.set(t) + if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + if err := setSchema[Out](&t.OutputSchema, &st.outputResolved); err != nil { + return nil, err + } } - t.rawHandler = func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { + st.handler = func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { var args In if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, t.inputResolved, &args); err != nil { + if err := unmarshalSchema(rparams.Arguments, st.inputResolved, &args); err != nil { return nil, err } } @@ -91,55 +61,41 @@ func newServerToolErr[In, Out any](name, description string, handler ToolHandler Name: rparams.Name, Arguments: args, } - res, err := handler(ctx, ss, params) + res, err := h(ctx, ss, params) + // TODO(rfindley): investigate why server errors are embedded in this strange way, + // rather than returned as jsonrpc2 server errors. if err != nil { - return nil, err + return &CallToolResult{ + Content: []Content{&TextContent{Text: err.Error()}}, + IsError: true, + }, nil } - var ctr CallToolResult + // TODO(jba): What if res == nil? Is that valid? + // TODO(jba): if t.OutputSchema != nil, check that StructuredContent is present and validates. if res != nil { // TODO(jba): future-proof this copy. ctr.Meta = res.Meta ctr.Content = res.Content ctr.IsError = res.IsError + ctr.StructuredContent = res.StructuredContent } return &ctr, nil } - return t, nil + + return st, nil } -// newRawHandler creates a rawToolHandler for tools not created through NewServerTool. -// It unmarshals the arguments into a map[string]any and validates them against the -// schema, then calls the ServerTool's handler. -func newRawHandler(st *ServerTool) rawToolHandler { - if st.Handler == nil { - panic("st.Handler is nil") +func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error { + var err error + if *sfield == nil { + *sfield, err = jsonschema.For[T]() } - return func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { - // Unmarshal the args into what should be a map. - var args map[string]any - if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, st.inputResolved, &args); err != nil { - return nil, err - } - } - // TODO: generate copy - params := &CallToolParamsFor[map[string]any]{ - Meta: rparams.Meta, - Name: rparams.Name, - Arguments: args, - } - res, err := st.Handler(ctx, ss, params) - // TODO(rfindley): investigate why server errors are embedded in this strange way, - // rather than returned as jsonrpc2 server errors. - if err != nil { - return &CallToolResult{ - Content: []Content{&TextContent{Text: err.Error()}}, - IsError: true, - }, nil - } - return res, nil + if err != nil { + return err } + *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return err } // unmarshalSchema unmarshals data into v and validates the result according to @@ -169,105 +125,6 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) return nil } -// A ToolOption configures the behavior of a Tool. -type ToolOption interface { - set(*ServerTool) -} - -type toolSetter func(*ServerTool) - -func (s toolSetter) set(t *ServerTool) { s(t) } - -// Input applies the provided [SchemaOption] configuration to the tool's input -// schema. -func Input(opts ...SchemaOption) ToolOption { - return toolSetter(func(t *ServerTool) { - for _, opt := range opts { - opt.set(t.Tool.InputSchema) - } - }) -} - -// A SchemaOption configures a jsonschema.Schema. -type SchemaOption interface { - set(s *jsonschema.Schema) -} - -type schemaSetter func(*jsonschema.Schema) - -func (s schemaSetter) set(schema *jsonschema.Schema) { s(schema) } - -// Property configures the schema for the property of the given name. -// If there is no such property in the schema, it is created. -func Property(name string, opts ...SchemaOption) SchemaOption { - return schemaSetter(func(schema *jsonschema.Schema) { - propSchema, ok := schema.Properties[name] - if !ok { - propSchema = new(jsonschema.Schema) - schema.Properties[name] = propSchema - } - // Apply the options, with special handling for Required, as it needs to be - // set on the parent schema. - for _, opt := range opts { - if req, ok := opt.(required); ok { - if req { - if !slices.Contains(schema.Required, name) { - schema.Required = append(schema.Required, name) - } - } else { - schema.Required = slices.DeleteFunc(schema.Required, func(s string) bool { - return s == name - }) - } - } else { - opt.set(propSchema) - } - } - }) -} - -// Required sets whether the associated property is required. It is only valid -// when used in a [Property] option: using Required outside of Property panics. -func Required(v bool) SchemaOption { - return required(v) -} - -// required must be a distinguished type as it needs special handling to mutate -// the parent schema, and to mutate prompt arguments. -type required bool - -func (required) set(s *jsonschema.Schema) { - panic("use of required outside of Property") -} - -// Enum sets the provided values as the "enum" value of the schema. -func Enum(values ...any) SchemaOption { - return schemaSetter(func(s *jsonschema.Schema) { - s.Enum = values - }) -} - -// Description sets the provided schema description. -func Description(desc string) SchemaOption { - return description(desc) -} - -// description must be a distinguished type so that it can be handled by prompt -// options. -type description string - -func (d description) set(s *jsonschema.Schema) { - s.Description = string(d) -} - -// Schema overrides the inferred schema with a shallow copy of the given -// schema. -func Schema(schema *jsonschema.Schema) SchemaOption { - return schemaSetter(func(s *jsonschema.Schema) { - *s = *schema - }) -} - // schemaJSON returns the JSON value for s as a string, or a string indicating an error. func schemaJSON(s *jsonschema.Schema) string { m, err := json.Marshal(s) diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 85775e9b..4d0a329b 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -16,76 +16,80 @@ import ( ) // testToolHandler is used for type inference in TestNewServerTool. -func testToolHandler[T any](context.Context, *ServerSession, *CallToolParamsFor[T]) (*CallToolResultFor[any], error) { +func testToolHandler[In, Out any](context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) { panic("not implemented") } +func srvTool[In, Out any](t *testing.T, tool *Tool, handler ToolHandlerFor[In, Out]) *serverTool { + t.Helper() + st, err := newServerTool(tool, handler) + if err != nil { + t.Fatal(err) + } + return st +} + func TestNewServerTool(t *testing.T) { + type ( + Name struct { + Name string `json:"name"` + } + Size struct { + Size int `json:"size"` + } + ) + + nameSchema := &jsonschema.Schema{ + Type: "object", + Required: []string{"name"}, + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, + } + sizeSchema := &jsonschema.Schema{ + Type: "object", + Required: []string{"size"}, + Properties: map[string]*jsonschema.Schema{ + "size": {Type: "integer"}, + }, + AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, + } + tests := []struct { - tool *ServerTool - want *jsonschema.Schema + tool *serverTool + wantIn, wantOut *jsonschema.Schema }{ { - NewServerTool("basic", "", testToolHandler[struct { - Name string `json:"name"` - }]), - &jsonschema.Schema{ - Type: "object", - Required: []string{"name"}, - Properties: map[string]*jsonschema.Schema{ - "name": {Type: "string"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, + srvTool(t, &Tool{Name: "basic"}, testToolHandler[Name, Size]), + nameSchema, + sizeSchema, }, { - NewServerTool("enum", "", testToolHandler[struct{ Name string }], Input( - Property("Name", Enum("x", "y", "z")), - )), - &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string", Enum: []any{"x", "y", "z"}}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, + srvTool(t, &Tool{ + Name: "in untouched", + InputSchema: &jsonschema.Schema{}, + }, testToolHandler[Name, Size]), + &jsonschema.Schema{}, + sizeSchema, }, { - NewServerTool("required", "", testToolHandler[struct { - Name string `json:"name"` - Language string `json:"language"` - X int `json:"x,omitempty"` - Y int `json:"y,omitempty"` - }], Input( - Property("x", Required(true)))), - &jsonschema.Schema{ - Type: "object", - Required: []string{"name", "language", "x"}, - Properties: map[string]*jsonschema.Schema{ - "language": {Type: "string"}, - "name": {Type: "string"}, - "x": {Type: "integer"}, - "y": {Type: "integer"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, + srvTool(t, &Tool{Name: "out untouched", OutputSchema: &jsonschema.Schema{}}, testToolHandler[Name, Size]), + nameSchema, + &jsonschema.Schema{}, }, { - NewServerTool("set_schema", "", testToolHandler[struct { - X int `json:"x,omitempty"` - Y int `json:"y,omitempty"` - }], Input( - Schema(&jsonschema.Schema{Type: "object"})), - ), - &jsonschema.Schema{ - Type: "object", - }, + srvTool(t, &Tool{Name: "nil out"}, testToolHandler[Name, any]), + nameSchema, + nil, }, } for _, test := range tests { - if diff := cmp.Diff(test.want, test.tool.Tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("NewServerTool(%v) mismatch (-want +got):\n%s", test.tool.Tool.Name, diff) + if diff := cmp.Diff(test.wantIn, test.tool.tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("newServerTool(%q) input schema mismatch (-want +got):\n%s", test.tool.tool.Name, diff) + } + if diff := cmp.Diff(test.wantOut, test.tool.tool.OutputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("newServerTool(%q) output schema mismatch (-want +got):\n%s", test.tool.tool.Name, diff) } } } From 923c2d7a7ed71d077fb7b63ebbe5096996edc6fb Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 8 Jul 2025 14:08:01 -0400 Subject: [PATCH 006/221] testdata: explain omitted test files (#106) Explain why certain files are omitted from our tests. --- jsonschema/testdata/draft2020-12/README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/jsonschema/testdata/draft2020-12/README.md b/jsonschema/testdata/draft2020-12/README.md index 09ae5704..dbc397dd 100644 --- a/jsonschema/testdata/draft2020-12/README.md +++ b/jsonschema/testdata/draft2020-12/README.md @@ -2,3 +2,14 @@ These files were copied from https://github.com/json-schema-org/JSON-Schema-Test-Suite/tree/83e866b46c9f9e7082fd51e83a61c5f2145a1ab7/tests/draft2020-12. + +The following files were omitted: + +content.json: it is not required to validate content fields +(https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8.1). + +format.json: it is not required to validate format fields (https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7.1). + +vocabulary.json: this package doesn't support explicit vocabularies, other than the 2020-12 draft. + +The "optional" directory: this package doesn't implement any optional features. From cdf34cc9da2ad23722af8725f4fe097ed3006667 Mon Sep 17 00:00:00 2001 From: Gerard Adam Date: Wed, 9 Jul 2025 02:46:47 +0700 Subject: [PATCH 007/221] mcp: memory server example (#90) This PR adds an example of implementing go-sdk on [memory server](https://github.com/modelcontextprotocol/servers/tree/main/src/memory). For #33. --- examples/memory/kb.go | 582 ++++++++++++++++++++++++++++++ examples/memory/kb_test.go | 703 +++++++++++++++++++++++++++++++++++++ examples/memory/main.go | 145 ++++++++ 3 files changed, 1430 insertions(+) create mode 100644 examples/memory/kb.go create mode 100644 examples/memory/kb_test.go create mode 100644 examples/memory/main.go diff --git a/examples/memory/kb.go b/examples/memory/kb.go new file mode 100644 index 00000000..a274f057 --- /dev/null +++ b/examples/memory/kb.go @@ -0,0 +1,582 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Entity represents a knowledge graph node with observations. +type Entity struct { + Name string `json:"name"` + EntityType string `json:"entityType"` + Observations []string `json:"observations"` +} + +// Relation represents a directed edge between two entities. +type Relation struct { + From string `json:"from"` + To string `json:"to"` + RelationType string `json:"relationType"` +} + +// Observation contains facts about an entity. +type Observation struct { + EntityName string `json:"entityName"` + Contents []string `json:"contents"` + + Observations []string `json:"observations,omitempty"` // Used for deletion operations +} + +// KnowledgeGraph represents the complete graph structure. +type KnowledgeGraph struct { + Entities []Entity `json:"entities"` + Relations []Relation `json:"relations"` +} + +// store provides persistence interface for knowledge base data. +type store interface { + Read() ([]byte, error) + Write(data []byte) error +} + +// memoryStore implements in-memory storage that doesn't persist across restarts. +type memoryStore struct { + data []byte +} + +// Read returns the in-memory data. +func (ms *memoryStore) Read() ([]byte, error) { + return ms.data, nil +} + +// Write stores data in memory. +func (ms *memoryStore) Write(data []byte) error { + ms.data = data + return nil +} + +// fileStore implements file-based storage for persistent knowledge base. +type fileStore struct { + path string +} + +// Read loads data from file, returning empty slice if file doesn't exist. +func (fs *fileStore) Read() ([]byte, error) { + data, err := os.ReadFile(fs.path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to read file %s: %w", fs.path, err) + } + return data, nil +} + +// Write saves data to file with 0600 permissions. +func (fs *fileStore) Write(data []byte) error { + if err := os.WriteFile(fs.path, data, 0600); err != nil { + return fmt.Errorf("failed to write file %s: %w", fs.path, err) + } + return nil +} + +// knowledgeBase manages entities and relations with persistent storage. +type knowledgeBase struct { + s store +} + +// kbItem represents a single item in persistent storage (entity or relation). +type kbItem struct { + Type string `json:"type"` + + // Entity fields (when Type == "entity") + Name string `json:"name,omitempty"` + EntityType string `json:"entityType,omitempty"` + Observations []string `json:"observations,omitempty"` + + // Relation fields (when Type == "relation") + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + RelationType string `json:"relationType,omitempty"` +} + +// loadGraph deserializes the knowledge graph from storage. +func (k knowledgeBase) loadGraph() (KnowledgeGraph, error) { + data, err := k.s.Read() + if err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to read from store: %w", err) + } + + if len(data) == 0 { + return KnowledgeGraph{}, nil + } + + var items []kbItem + if err := json.Unmarshal(data, &items); err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to unmarshal from store: %w", err) + } + + graph := KnowledgeGraph{} + + for _, item := range items { + switch item.Type { + case "entity": + graph.Entities = append(graph.Entities, Entity{ + Name: item.Name, + EntityType: item.EntityType, + Observations: item.Observations, + }) + case "relation": + graph.Relations = append(graph.Relations, Relation{ + From: item.From, + To: item.To, + RelationType: item.RelationType, + }) + } + } + + return graph, nil +} + +// saveGraph serializes and persists the knowledge graph to storage. +func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error { + items := make([]kbItem, 0, len(graph.Entities)+len(graph.Relations)) + + for _, entity := range graph.Entities { + items = append(items, kbItem{ + Type: "entity", + Name: entity.Name, + EntityType: entity.EntityType, + Observations: entity.Observations, + }) + } + + for _, relation := range graph.Relations { + items = append(items, kbItem{ + Type: "relation", + From: relation.From, + To: relation.To, + RelationType: relation.RelationType, + }) + } + + itemsJSON, err := json.Marshal(items) + if err != nil { + return fmt.Errorf("failed to marshal items: %w", err) + } + + if err := k.s.Write(itemsJSON); err != nil { + return fmt.Errorf("failed to write to store: %w", err) + } + return nil +} + +// createEntities adds new entities to the graph, skipping duplicates by name. +// It returns the new entities that were actually added. +func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newEntities []Entity + for _, entity := range entities { + if !slices.ContainsFunc(graph.Entities, func(e Entity) bool { return e.Name == entity.Name }) { + newEntities = append(newEntities, entity) + graph.Entities = append(graph.Entities, entity) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newEntities, nil +} + +// createRelations adds new relations to the graph, skipping exact duplicates. +// It returns the new relations that were actually added. +func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newRelations []Relation + for _, relation := range relations { + exists := slices.ContainsFunc(graph.Relations, func(r Relation) bool { + return r.From == relation.From && + r.To == relation.To && + r.RelationType == relation.RelationType + }) + if !exists { + newRelations = append(newRelations, relation) + graph.Relations = append(graph.Relations, relation) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newRelations, nil +} + +// addObservations appends new observations to existing entities. +// It returns the new observations that were actually added. +func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var results []Observation + + for _, obs := range observations { + entityIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { return e.Name == obs.EntityName }) + if entityIndex == -1 { + return nil, fmt.Errorf("entity with name %s not found", obs.EntityName) + } + + var newObservations []string + for _, content := range obs.Contents { + if !slices.Contains(graph.Entities[entityIndex].Observations, content) { + newObservations = append(newObservations, content) + graph.Entities[entityIndex].Observations = append(graph.Entities[entityIndex].Observations, content) + } + } + + results = append(results, Observation{ + EntityName: obs.EntityName, + Contents: newObservations, + }) + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return results, nil +} + +// deleteEntities removes entities and their associated relations. +func (k knowledgeBase) deleteEntities(entityNames []string) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + // Create map for quick lookup + entitiesToDelete := make(map[string]bool) + for _, name := range entityNames { + entitiesToDelete[name] = true + } + + // Filter entities using slices.DeleteFunc + graph.Entities = slices.DeleteFunc(graph.Entities, func(entity Entity) bool { + return entitiesToDelete[entity.Name] + }) + + // Filter relations using slices.DeleteFunc + graph.Relations = slices.DeleteFunc(graph.Relations, func(relation Relation) bool { + return entitiesToDelete[relation.From] || entitiesToDelete[relation.To] + }) + + return k.saveGraph(graph) +} + +// deleteObservations removes specific observations from entities. +func (k knowledgeBase) deleteObservations(deletions []Observation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + for _, deletion := range deletions { + entityIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { + return e.Name == deletion.EntityName + }) + if entityIndex == -1 { + continue + } + + // Create a map for quick lookup + observationsToDelete := make(map[string]bool) + for _, observation := range deletion.Observations { + observationsToDelete[observation] = true + } + + // Filter observations using slices.DeleteFunc + graph.Entities[entityIndex].Observations = slices.DeleteFunc(graph.Entities[entityIndex].Observations, func(observation string) bool { + return observationsToDelete[observation] + }) + } + + return k.saveGraph(graph) +} + +// deleteRelations removes specific relations from the graph. +func (k knowledgeBase) deleteRelations(relations []Relation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + // Filter relations using slices.DeleteFunc and slices.ContainsFunc + graph.Relations = slices.DeleteFunc(graph.Relations, func(existingRelation Relation) bool { + return slices.ContainsFunc(relations, func(relationToDelete Relation) bool { + return existingRelation.From == relationToDelete.From && + existingRelation.To == relationToDelete.To && + existingRelation.RelationType == relationToDelete.RelationType + }) + }) + return k.saveGraph(graph) +} + +// searchNodes filters entities and relations matching the query string. +func (k knowledgeBase) searchNodes(query string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + queryLower := strings.ToLower(query) + var filteredEntities []Entity + + // Filter entities + for _, entity := range graph.Entities { + if strings.Contains(strings.ToLower(entity.Name), queryLower) || + strings.Contains(strings.ToLower(entity.EntityType), queryLower) { + filteredEntities = append(filteredEntities, entity) + continue + } + + // Check observations + for _, observation := range entity.Observations { + if strings.Contains(strings.ToLower(observation), queryLower) { + filteredEntities = append(filteredEntities, entity) + break + } + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +// openNodes returns entities with specified names and their interconnecting relations. +func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + // Create map for quick name lookup + nameSet := make(map[string]bool) + for _, name := range names { + nameSet[name] = true + } + + // Filter entities + var filteredEntities []Entity + for _, entity := range graph.Entities { + if nameSet[entity.Name] { + filteredEntities = append(filteredEntities, entity) + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateEntitiesArgs]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { + var res mcp.CallToolResultFor[CreateEntitiesResult] + + entities, err := k.createEntities(params.Arguments.Entities) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Entities created successfully"}, + } + + res.StructuredContent = CreateEntitiesResult{ + Entities: entities, + } + + return &res, nil +} + +func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateRelationsArgs]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { + var res mcp.CallToolResultFor[CreateRelationsResult] + + relations, err := k.createRelations(params.Arguments.Relations) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Relations created successfully"}, + } + + res.StructuredContent = CreateRelationsResult{ + Relations: relations, + } + + return &res, nil +} + +func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddObservationsArgs]) (*mcp.CallToolResultFor[AddObservationsResult], error) { + var res mcp.CallToolResultFor[AddObservationsResult] + + observations, err := k.addObservations(params.Arguments.Observations) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Observations added successfully"}, + } + + res.StructuredContent = AddObservationsResult{ + Observations: observations, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteEntitiesArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteEntities(params.Arguments.EntityNames) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Entities deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteObservationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteObservations(params.Arguments.Deletions) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Observations deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteRelationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteRelations(params.Arguments.Relations) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Relations deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[struct{}]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Graph read successfully"}, + } + + res.StructuredContent = graph + return &res, nil +} + +func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[SearchNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.searchNodes(params.Arguments.Query) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Nodes searched successfully"}, + } + + res.StructuredContent = graph + return &res, nil +} + +func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[OpenNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.openNodes(params.Arguments.Names) + if err != nil { + return nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Nodes opened successfully"}, + } + + res.StructuredContent = graph + return &res, nil +} diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go new file mode 100644 index 00000000..e4fbacc9 --- /dev/null +++ b/examples/memory/kb_test.go @@ -0,0 +1,703 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "slices" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// stores provides test factories for both storage implementations. +func stores() map[string]func(t *testing.T) store { + return map[string]func(t *testing.T) store{ + "file": func(t *testing.T) store { + tempDir, err := os.MkdirTemp("", "kb-test-file-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + t.Cleanup(func() { os.RemoveAll(tempDir) }) + return &fileStore{path: filepath.Join(tempDir, "test-memory.json")} + }, + "memory": func(t *testing.T) store { + return &memoryStore{} + }, + } +} + +// TestKnowledgeBaseOperations verifies CRUD operations work correctly. +func TestKnowledgeBaseOperations(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Verify empty graph loads correctly + graph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load empty graph: %v", err) + } + if len(graph.Entities) != 0 || len(graph.Relations) != 0 { + t.Errorf("expected empty graph, got %+v", graph) + } + + // Create and verify entities + testEntities := []Entity{ + { + Name: "Alice", + EntityType: "Person", + Observations: []string{"Likes coffee"}, + }, + { + Name: "Bob", + EntityType: "Person", + Observations: []string{"Likes tea"}, + }, + } + + createdEntities, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create entities: %v", err) + } + if len(createdEntities) != 2 { + t.Errorf("expected 2 created entities, got %d", len(createdEntities)) + } + + // Verify entities persist + graph, err = kb.loadGraph() + if err != nil { + t.Fatalf("failed to read graph: %v", err) + } + if len(graph.Entities) != 2 { + t.Errorf("expected 2 entities, got %d", len(graph.Entities)) + } + + // Create and verify relations + testRelations := []Relation{ + { + From: "Alice", + To: "Bob", + RelationType: "friend", + }, + } + + createdRelations, err := kb.createRelations(testRelations) + if err != nil { + t.Fatalf("failed to create relations: %v", err) + } + if len(createdRelations) != 1 { + t.Errorf("expected 1 created relation, got %d", len(createdRelations)) + } + + // Add observations to entities + testObservations := []Observation{ + { + EntityName: "Alice", + Contents: []string{"Works as developer", "Lives in New York"}, + }, + } + + addedObservations, err := kb.addObservations(testObservations) + if err != nil { + t.Fatalf("failed to add observations: %v", err) + } + if len(addedObservations) != 1 || len(addedObservations[0].Contents) != 2 { + t.Errorf("expected 1 observation with 2 contents, got %+v", addedObservations) + } + + // Search nodes by content + searchResult, err := kb.searchNodes("developer") + if err != nil { + t.Fatalf("failed to search nodes: %v", err) + } + if len(searchResult.Entities) != 1 || searchResult.Entities[0].Name != "Alice" { + t.Errorf("expected to find Alice when searching for 'developer', got %+v", searchResult) + } + + // Retrieve specific nodes + openResult, err := kb.openNodes([]string{"Bob"}) + if err != nil { + t.Fatalf("failed to open nodes: %v", err) + } + if len(openResult.Entities) != 1 || openResult.Entities[0].Name != "Bob" { + t.Errorf("expected to find Bob when opening 'Bob', got %+v", openResult) + } + + // Remove specific observations + deleteObs := []Observation{ + { + EntityName: "Alice", + Observations: []string{"Works as developer"}, + }, + } + err = kb.deleteObservations(deleteObs) + if err != nil { + t.Fatalf("failed to delete observations: %v", err) + } + + // Confirm observation removal + graph, _ = kb.loadGraph() + aliceIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { + return e.Name == "Alice" + }) + if aliceIndex == -1 { + t.Errorf("entity 'Alice' not found after deleting observation") + } else { + alice := graph.Entities[aliceIndex] + if slices.Contains(alice.Observations, "Works as developer") { + t.Errorf("observation 'Works as developer' should have been deleted") + } + } + + // Remove relations + err = kb.deleteRelations(testRelations) + if err != nil { + t.Fatalf("failed to delete relations: %v", err) + } + + // Confirm relation removal + graph, _ = kb.loadGraph() + if len(graph.Relations) != 0 { + t.Errorf("expected 0 relations after deletion, got %d", len(graph.Relations)) + } + + // Remove entities + err = kb.deleteEntities([]string{"Alice"}) + if err != nil { + t.Fatalf("failed to delete entities: %v", err) + } + + // Confirm entity removal + graph, _ = kb.loadGraph() + if len(graph.Entities) != 1 || graph.Entities[0].Name != "Bob" { + t.Errorf("expected only Bob to remain after deleting Alice, got %+v", graph.Entities) + } + }) + } +} + +// TestSaveAndLoadGraph ensures data persists correctly across save/load cycles. +func TestSaveAndLoadGraph(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup test data + testGraph := KnowledgeGraph{ + Entities: []Entity{ + { + Name: "Charlie", + EntityType: "Person", + Observations: []string{"Likes hiking"}, + }, + }, + Relations: []Relation{ + { + From: "Charlie", + To: "Mountains", + RelationType: "enjoys", + }, + }, + } + + // Persist to storage + err := kb.saveGraph(testGraph) + if err != nil { + t.Fatalf("failed to save graph: %v", err) + } + + // Reload from storage + loadedGraph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load graph: %v", err) + } + + // Verify data integrity + if !reflect.DeepEqual(testGraph, loadedGraph) { + t.Errorf("loaded graph does not match saved graph.\nExpected: %+v\nGot: %+v", testGraph, loadedGraph) + } + + // Test malformed data handling + if fs, ok := s.(*fileStore); ok { + err := os.WriteFile(fs.path, []byte("invalid json"), 0600) + if err != nil { + t.Fatalf("failed to write invalid json: %v", err) + } + + _, err = kb.loadGraph() + if err == nil { + t.Errorf("expected error when loading invalid JSON, got nil") + } + } + }) + } +} + +// TestDuplicateEntitiesAndRelations verifies duplicate prevention logic. +func TestDuplicateEntitiesAndRelations(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup initial state + initialEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Plays guitar"}, + }, + } + + _, err := kb.createEntities(initialEntities) + if err != nil { + t.Fatalf("failed to create initial entities: %v", err) + } + + // Attempt duplicate creation + duplicateEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Sings well"}, + }, + { + Name: "Eve", + EntityType: "Person", + Observations: []string{"Plays piano"}, + }, + } + + newEntities, err := kb.createEntities(duplicateEntities) + if err != nil { + t.Fatalf("failed when adding duplicate entities: %v", err) + } + + // Verify only new entities created + if len(newEntities) != 1 || newEntities[0].Name != "Eve" { + t.Errorf("expected only 'Eve' to be created, got %+v", newEntities) + } + + // Setup initial relation + initialRelation := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + } + + _, err = kb.createRelations(initialRelation) + if err != nil { + t.Fatalf("failed to create initial relation: %v", err) + } + + // Test relation deduplication + duplicateRelations := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + { + From: "Eve", + To: "Dave", + RelationType: "friend", + }, + } + + newRelations, err := kb.createRelations(duplicateRelations) + if err != nil { + t.Fatalf("failed when adding duplicate relations: %v", err) + } + + // Verify only new relations created + if len(newRelations) != 1 || newRelations[0].From != "Eve" || newRelations[0].To != "Dave" { + t.Errorf("expected only 'Eve->Dave' relation to be created, got %+v", newRelations) + } + }) + } +} + +// TestErrorHandling verifies proper error responses for invalid operations. +func TestErrorHandling(t *testing.T) { + t.Run("FileStoreWriteError", func(t *testing.T) { + // Test file write to invalid path + kb := knowledgeBase{ + s: &fileStore{path: filepath.Join("nonexistent", "directory", "file.json")}, + } + + testEntities := []Entity{ + {Name: "TestEntity"}, + } + + _, err := kb.createEntities(testEntities) + if err == nil { + t.Errorf("expected error when writing to non-existent directory, got nil") + } + }) + + for name, newStore := range stores() { + t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup valid entity for comparison + _, err := kb.createEntities([]Entity{{Name: "RealEntity"}}) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Test invalid entity reference + nonExistentObs := []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This shouldn't work"}, + }, + } + + _, err = kb.addObservations(nonExistentObs) + if err == nil { + t.Errorf("expected error when adding observations to non-existent entity, got nil") + } + }) + } +} + +// TestFileFormatting verifies the JSON storage format structure. +func TestFileFormatting(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup test entity + testEntities := []Entity{ + { + Name: "FileTest", + EntityType: "TestEntity", + Observations: []string{"Test observation"}, + }, + } + + _, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Extract raw storage data + data, err := s.Read() + if err != nil { + t.Fatalf("failed to read from store: %v", err) + } + + // Validate JSON format + var items []kbItem + err = json.Unmarshal(data, &items) + if err != nil { + t.Fatalf("failed to parse store data JSON: %v", err) + } + + // Check data structure + if len(items) != 1 { + t.Fatalf("expected 1 item in memory file, got %d", len(items)) + } + + item := items[0] + if item.Type != "entity" || + item.Name != "FileTest" || + item.EntityType != "TestEntity" || + len(item.Observations) != 1 || + item.Observations[0] != "Test observation" { + t.Errorf("store item format incorrect: %+v", item) + } + }) + } +} + +// TestMCPServerIntegration tests the knowledge base through MCP server layer. +func TestMCPServerIntegration(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Create mock server session + ctx := context.Background() + serverSession := &mcp.ServerSession{} + + // Test CreateEntities through MCP + createEntitiesParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ + Arguments: CreateEntitiesArgs{ + Entities: []Entity{ + { + Name: "TestPerson", + EntityType: "Person", + Observations: []string{"Likes testing"}, + }, + }, + }, + } + + createResult, err := kb.CreateEntities(ctx, serverSession, createEntitiesParams) + if err != nil { + t.Fatalf("MCP CreateEntities failed: %v", err) + } + if createResult.IsError { + t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) + } + if len(createResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity created, got %d", len(createResult.StructuredContent.Entities)) + } + + // Test ReadGraph through MCP + readParams := &mcp.CallToolParamsFor[struct{}]{} + readResult, err := kb.ReadGraph(ctx, serverSession, readParams) + if err != nil { + t.Fatalf("MCP ReadGraph failed: %v", err) + } + if readResult.IsError { + t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) + } + if len(readResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity in graph, got %d", len(readResult.StructuredContent.Entities)) + } + + // Test CreateRelations through MCP + createRelationsParams := &mcp.CallToolParamsFor[CreateRelationsArgs]{ + Arguments: CreateRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }, + } + + relationsResult, err := kb.CreateRelations(ctx, serverSession, createRelationsParams) + if err != nil { + t.Fatalf("MCP CreateRelations failed: %v", err) + } + if relationsResult.IsError { + t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) + } + if len(relationsResult.StructuredContent.Relations) != 1 { + t.Errorf("expected 1 relation created, got %d", len(relationsResult.StructuredContent.Relations)) + } + + // Test AddObservations through MCP + addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ + Arguments: AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "TestPerson", + Contents: []string{"Works remotely", "Drinks coffee"}, + }, + }, + }, + } + + obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams) + if err != nil { + t.Fatalf("MCP AddObservations failed: %v", err) + } + if obsResult.IsError { + t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) + } + if len(obsResult.StructuredContent.Observations) != 1 { + t.Errorf("expected 1 observation result, got %d", len(obsResult.StructuredContent.Observations)) + } + + // Test SearchNodes through MCP + searchParams := &mcp.CallToolParamsFor[SearchNodesArgs]{ + Arguments: SearchNodesArgs{ + Query: "coffee", + }, + } + + searchResult, err := kb.SearchNodes(ctx, serverSession, searchParams) + if err != nil { + t.Fatalf("MCP SearchNodes failed: %v", err) + } + if searchResult.IsError { + t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) + } + if len(searchResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity from search, got %d", len(searchResult.StructuredContent.Entities)) + } + + // Test OpenNodes through MCP + openParams := &mcp.CallToolParamsFor[OpenNodesArgs]{ + Arguments: OpenNodesArgs{ + Names: []string{"TestPerson"}, + }, + } + + openResult, err := kb.OpenNodes(ctx, serverSession, openParams) + if err != nil { + t.Fatalf("MCP OpenNodes failed: %v", err) + } + if openResult.IsError { + t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) + } + if len(openResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity from open, got %d", len(openResult.StructuredContent.Entities)) + } + + // Test DeleteObservations through MCP + deleteObsParams := &mcp.CallToolParamsFor[DeleteObservationsArgs]{ + Arguments: DeleteObservationsArgs{ + Deletions: []Observation{ + { + EntityName: "TestPerson", + Observations: []string{"Works remotely"}, + }, + }, + }, + } + + deleteObsResult, err := kb.DeleteObservations(ctx, serverSession, deleteObsParams) + if err != nil { + t.Fatalf("MCP DeleteObservations failed: %v", err) + } + if deleteObsResult.IsError { + t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) + } + + // Test DeleteRelations through MCP + deleteRelParams := &mcp.CallToolParamsFor[DeleteRelationsArgs]{ + Arguments: DeleteRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }, + } + + deleteRelResult, err := kb.DeleteRelations(ctx, serverSession, deleteRelParams) + if err != nil { + t.Fatalf("MCP DeleteRelations failed: %v", err) + } + if deleteRelResult.IsError { + t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) + } + + // Test DeleteEntities through MCP + deleteEntParams := &mcp.CallToolParamsFor[DeleteEntitiesArgs]{ + Arguments: DeleteEntitiesArgs{ + EntityNames: []string{"TestPerson"}, + }, + } + + deleteEntResult, err := kb.DeleteEntities(ctx, serverSession, deleteEntParams) + if err != nil { + t.Fatalf("MCP DeleteEntities failed: %v", err) + } + if deleteEntResult.IsError { + t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) + } + + // Verify final state + finalRead, err := kb.ReadGraph(ctx, serverSession, readParams) + if err != nil { + t.Fatalf("Final MCP ReadGraph failed: %v", err) + } + if len(finalRead.StructuredContent.Entities) != 0 { + t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.StructuredContent.Entities)) + } + }) + } +} + +// TestMCPErrorHandling tests error scenarios through MCP layer. +func TestMCPErrorHandling(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + ctx := context.Background() + serverSession := &mcp.ServerSession{} + + // Test adding observations to non-existent entity + addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ + Arguments: AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This should fail"}, + }, + }, + }, + } + + _, err := kb.AddObservations(ctx, serverSession, addObsParams) + if err == nil { + t.Errorf("expected MCP AddObservations to return error for non-existent entity") + } else { + // Verify the error message contains expected text + want := "entity with name NonExistentEntity not found" + if !strings.Contains(err.Error(), want) { + t.Errorf("expected error message to contain '%s', got: %v", want, err) + } + } + }) + } +} + +// TestMCPResponseFormat verifies MCP response format consistency. +func TestMCPResponseFormat(t *testing.T) { + s := &memoryStore{} + kb := knowledgeBase{s: s} + + ctx := context.Background() + serverSession := &mcp.ServerSession{} + + // Test CreateEntities response format + createParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ + Arguments: CreateEntitiesArgs{ + Entities: []Entity{ + {Name: "FormatTest", EntityType: "Test"}, + }, + }, + } + + result, err := kb.CreateEntities(ctx, serverSession, createParams) + if err != nil { + t.Fatalf("CreateEntities failed: %v", err) + } + + // Verify response has both Content and StructuredContent + if len(result.Content) == 0 { + t.Errorf("expected Content field to be populated") + } + if len(result.StructuredContent.Entities) == 0 { + t.Errorf("expected StructuredContent.Entities to be populated") + } + + // Verify Content contains simple success message + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { + expectedMessage := "Entities created successfully" + if textContent.Text != expectedMessage { + t.Errorf("expected Content field to contain '%s', got '%s'", expectedMessage, textContent.Text) + } + } else { + t.Errorf("expected Content[0] to be TextContent") + } +} diff --git a/examples/memory/main.go b/examples/memory/main.go new file mode 100644 index 00000000..d3d78110 --- /dev/null +++ b/examples/memory/main.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "log" + "net/http" + "os" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + memoryFilePath = flag.String("memory", "", "if set, persist the knowledge base to this file; otherwise, it will be stored in memory and lost on exit") +) + +// HiArgs defines arguments for the greeting tool. +type HiArgs struct { + Name string `json:"name"` +} + +// CreateEntitiesArgs defines the create entities tool parameters. +type CreateEntitiesArgs struct { + Entities []Entity `json:"entities" mcp:"entities to create"` +} + +// CreateEntitiesResult returns newly created entities. +type CreateEntitiesResult struct { + Entities []Entity `json:"entities"` +} + +// CreateRelationsArgs defines the create relations tool parameters. +type CreateRelationsArgs struct { + Relations []Relation `json:"relations" mcp:"relations to create"` +} + +// CreateRelationsResult returns newly created relations. +type CreateRelationsResult struct { + Relations []Relation `json:"relations"` +} + +// AddObservationsArgs defines the add observations tool parameters. +type AddObservationsArgs struct { + Observations []Observation `json:"observations" mcp:"observations to add"` +} + +// AddObservationsResult returns newly added observations. +type AddObservationsResult struct { + Observations []Observation `json:"observations"` +} + +// DeleteEntitiesArgs defines the delete entities tool parameters. +type DeleteEntitiesArgs struct { + EntityNames []string `json:"entityNames" mcp:"entities to delete"` +} + +// DeleteObservationsArgs defines the delete observations tool parameters. +type DeleteObservationsArgs struct { + Deletions []Observation `json:"deletions" mcp:"obeservations to delete"` +} + +// DeleteRelationsArgs defines the delete relations tool parameters. +type DeleteRelationsArgs struct { + Relations []Relation `json:"relations" mcp:"relations to delete"` +} + +// SearchNodesArgs defines the search nodes tool parameters. +type SearchNodesArgs struct { + Query string `json:"query" mcp:"query string"` +} + +// OpenNodesArgs defines the open nodes tool parameters. +type OpenNodesArgs struct { + Names []string `json:"names" mcp:"names of nodes to open"` +} + +func main() { + flag.Parse() + + // Initialize storage backend + var kbStore store + kbStore = &memoryStore{} + if *memoryFilePath != "" { + kbStore = &fileStore{path: *memoryFilePath} + } + kb := knowledgeBase{s: kbStore} + + // Setup MCP server with knowledge base tools + server := mcp.NewServer("memory", "v0.0.1", nil) + mcp.AddTool(server, &mcp.Tool{ + Name: "create_entities", + Description: "Create multiple new entities in the knowledge graph", + }, kb.CreateEntities) + mcp.AddTool(server, &mcp.Tool{ + Name: "create_relations", + Description: "Create multiple new relations between entities", + }, kb.CreateRelations) + mcp.AddTool(server, &mcp.Tool{ + Name: "add_observations", + Description: "Add new observations to existing entities", + }, kb.AddObservations) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_entities", + Description: "Remove entities and their relations", + }, kb.DeleteEntities) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_observations", + Description: "Remove specific observations from entities", + }, kb.DeleteObservations) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_relations", + Description: "Remove specific relations from the graph", + }, kb.DeleteRelations) + mcp.AddTool(server, &mcp.Tool{ + Name: "read_graph", + Description: "Read the entire knowledge graph", + }, kb.ReadGraph) + mcp.AddTool(server, &mcp.Tool{ + Name: "search_nodes", + Description: "Search for nodes based on query", + }, kb.SearchNodes) + mcp.AddTool(server, &mcp.Tool{ + Name: "open_nodes", + Description: "Retrieve specific nodes by name", + }, kb.OpenNodes) + + // Start server with appropriate transport + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("MCP handler listening at %s", *httpAddr) + http.ListenAndServe(*httpAddr, handler) + } else { + t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} From fbff31af141350acb70b17740de56b01187a8741 Mon Sep 17 00:00:00 2001 From: cryo Date: Wed, 9 Jul 2025 03:52:54 +0800 Subject: [PATCH 008/221] mcp/Resource: Refactor URI Template Matching with uritemplate Library (#100) 1. Replaced the `uriTemplateToRegexp` function with the use of [yosida95/uritemplate](https://github.com/yosida95/uritemplate) to handle URI template matching as per the request. 2. Updated the relevant test cases to verify the new implementation and fixed an invalid URI template (`"file:///{}/{a}/{b}"`) that doesn't conform to RFC 6570 due to an empty `{}` expression. 3. Updated `go.mod` and `go.sum` to reflect the new dependency. Fixes #11. --- go.mod | 1 + go.sum | 2 ++ mcp/resource.go | 63 +++----------------------------------------- mcp/resource_test.go | 12 +++------ 4 files changed, 10 insertions(+), 68 deletions(-) diff --git a/go.mod b/go.mod index 24e187ab..9bf8c151 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 + github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 6c6c2a5d..7d2f581d 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/mcp/resource.go b/mcp/resource.go index 4202fdac..590e0672 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -13,11 +13,11 @@ import ( "net/url" "os" "path/filepath" - "regexp" "strings" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/yosida95/uritemplate/v3" ) // A serverResource associates a Resource with its handler. @@ -155,67 +155,10 @@ func fileRoot(root *Root) (_ string, err error) { } // Matches reports whether the receiver's uri template matches the uri. -// TODO: use "github.com/yosida95/uritemplate/v3" func (sr *serverResourceTemplate) Matches(uri string) bool { - re, err := uriTemplateToRegexp(sr.resourceTemplate.URITemplate) + tmpl, err := uritemplate.New(sr.resourceTemplate.URITemplate) if err != nil { return false } - return re.MatchString(uri) -} - -func uriTemplateToRegexp(uriTemplate string) (*regexp.Regexp, error) { - pat := uriTemplate - var b strings.Builder - b.WriteByte('^') - seen := map[string]bool{} - for len(pat) > 0 { - literal, rest, ok := strings.Cut(pat, "{") - b.WriteString(regexp.QuoteMeta(literal)) - if !ok { - break - } - expr, rest, ok := strings.Cut(rest, "}") - if !ok { - return nil, errors.New("missing '}'") - } - pat = rest - if strings.ContainsRune(expr, ',') { - return nil, errors.New("can't handle commas in expressions") - } - if strings.ContainsRune(expr, ':') { - return nil, errors.New("can't handle prefix modifiers in expressions") - } - if len(expr) > 0 && expr[len(expr)-1] == '*' { - return nil, errors.New("can't handle explode modifiers in expressions") - } - - // These sets of valid characters aren't accurate. - // See https://datatracker.ietf.org/doc/html/rfc6570. - var re, name string - first := byte(0) - if len(expr) > 0 { - first = expr[0] - } - switch first { - default: - // {var} doesn't match slashes. (It should also fail to match other characters, - // but this simplified implementation doesn't handle that.) - re = `[^/]*` - name = expr - case '+': - // {+var} matches anything, even slashes - re = `.*` - name = expr[1:] - case '#', '.', '/', ';', '?', '&': - return nil, fmt.Errorf("prefix character %c unsupported", first) - } - if seen[name] { - return nil, fmt.Errorf("can't handle duplicate name %q", name) - } - seen[name] = true - b.WriteString(re) - } - b.WriteByte('$') - return regexp.Compile(b.String()) + return tmpl.Regexp().MatchString(uri) } diff --git a/mcp/resource_test.go b/mcp/resource_test.go index cb5e4fb9..74d82f12 100644 --- a/mcp/resource_test.go +++ b/mcp/resource_test.go @@ -119,18 +119,14 @@ func TestTemplateMatch(t *testing.T) { template string want bool }{ - {"file:///{}/{a}/{b}", true}, + {"file:///{}/{a}/{b}", false}, // invalid: empty variable expression "{}" is not allowed in RFC 6570 {"file:///{a}/{b}", false}, {"file:///{+path}", true}, {"file:///{a}/{+path}", true}, } { - re, err := uriTemplateToRegexp(tt.template) - if err != nil { - t.Fatalf("%s: %v", tt.template, err) - } - got := re.MatchString(uri) - if got != tt.want { - t.Errorf("%s: got %t, want %t", tt.template, got, tt.want) + resourceTmpl := serverResourceTemplate{resourceTemplate: &ResourceTemplate{URITemplate: tt.template}} + if matched := resourceTmpl.Matches(uri); matched != tt.want { + t.Errorf("%s: got %t, want %t", tt.template, matched, tt.want) } } } From 2b07560be04b02a658334daaf4556b27ba143d31 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Wed, 9 Jul 2025 08:54:53 -0700 Subject: [PATCH 009/221] mcp: fix sseClientConn to use custom HTTP client (#105) When attempting to create an OAuth authentication on top of the SSE transport I discovered the sseClientConn was using http.DefaultClient for requests instead of the custom client provided via SSEClientTransportOptions. This change stores the HTTP client in the sseClientConn struct and uses it for all requests. Also adds test to verify the custom client is properly used. --- mcp/sse.go | 10 ++++++---- mcp/sse_test.go | 25 ++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/mcp/sse.go b/mcp/sse.go index 0a1f9b1b..d1b52599 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -391,6 +391,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientConn{ + client: httpClient, sseEndpoint: c.sseEndpoint, msgEndpoint: msgEndpoint, incoming: make(chan []byte, 100), @@ -511,9 +512,10 @@ func scanEvents(r io.Reader) iter.Seq2[event, error] { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - sseEndpoint *url.URL // SSE endpoint for the GET - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + sseEndpoint *url.URL // SSE endpoint for the GET + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -564,7 +566,7 @@ func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { return err } req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := c.client.Do(req) if err != nil { return err } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index e1df9536..153185d3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "github.com/google/go-cmp/cmp" @@ -34,7 +35,17 @@ func TestSSEServer(t *testing.T) { httpServer := httptest.NewServer(sseHandler) defer httpServer.Close() - clientTransport := NewSSEClientTransport(httpServer.URL, nil) + var customClientUsed int64 + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt64(&customClientUsed, 1) + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := NewSSEClientTransport(httpServer.URL, &SSEClientTransportOptions{ + HTTPClient: customClient, + }) c := NewClient("testClient", "v1.0.0", nil) cs, err := c.Connect(ctx, clientTransport) @@ -61,6 +72,11 @@ func TestSSEServer(t *testing.T) { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } + // Verify that customClient was used + if atomic.LoadInt64(&customClientUsed) == 0 { + t.Error("Expected custom HTTP client to be used, but it wasn't") + } + // Test that closing either end of the connection terminates the other // end. if closeServerFirst { @@ -162,3 +178,10 @@ func TestScanEvents(t *testing.T) { }) } } + +// roundTripperFunc is a helper to create a custom RoundTripper +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} From e27a36ed320bd071b5a1e5b6f18ebb6dbb2ccf06 Mon Sep 17 00:00:00 2001 From: Albert Sundjaja <22909314+albertsundjaja@users.noreply.github.com> Date: Thu, 10 Jul 2025 22:55:45 +1000 Subject: [PATCH 010/221] jsonchema: add cycle detection (#97) This PR add cycle detection where it return error instead of waiting for stack overflow. This is a fix for https://github.com/modelcontextprotocol/go-sdk/issues/77. The root cause of the issue was the type being processed is deeply nested and has cycle. Hence, the heap memory growth is faster than the stack growth, it went OOM before stack overflow panic can kick in. --- jsonschema/infer.go | 22 +++++++--- jsonschema/infer_test.go | 93 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 1334bdf1..d1c1a5fb 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -37,9 +37,11 @@ import ( // - unsafe pointers // // The types must not have cycles. +// It will return an error if there is a cycle in the types. func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. - s, err := forType(reflect.TypeFor[T]()) + seen := make(map[reflect.Type]bool) + s, err := forType(reflect.TypeFor[T](), seen) if err != nil { var z T return nil, fmt.Errorf("For[%T](): %w", z, err) @@ -47,7 +49,7 @@ func For[T any]() (*Schema, error) { return s, nil } -func forType(t reflect.Type) (*Schema, error) { +func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -56,6 +58,16 @@ func forType(t reflect.Type) (*Schema, error) { t = t.Elem() } + // Check for cycles + // User defined types have a name, so we can skip those that are natively defined + if t.Name() != "" { + if seen[t] { + return nil, fmt.Errorf("cycle detected for type %v", t) + } + seen[t] = true + defer delete(seen, t) + } + var ( s = new(Schema) err error @@ -81,14 +93,14 @@ func forType(t reflect.Type) (*Schema, error) { return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem()) + s.AdditionalProperties, err = forType(t.Elem(), seen) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem()) + s.Items, err = forType(t.Elem(), seen) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } @@ -114,7 +126,7 @@ func forType(t reflect.Type) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[info.Name], err = forType(field.Type) + s.Properties[info.Name], err = forType(field.Type, seen) if err != nil { return nil, err } diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 9325b832..0b1b769a 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -91,3 +91,96 @@ func TestForType(t *testing.T) { }) } } + +func TestForWithMutation(t *testing.T) { + // This test ensures that the cached schema is not mutated when the caller + // mutates the returned schema. + type S struct { + A int + } + type T struct { + A int `json:"A"` + B map[string]int + C []S + D [3]S + E *bool + } + s, err := jsonschema.For[T]() + if err != nil { + t.Fatalf("For: %v", err) + } + s.Required[0] = "mutated" + s.Properties["A"].Type = "mutated" + s.Properties["C"].Items.Type = "mutated" + s.Properties["D"].MaxItems = jsonschema.Ptr(10) + s.Properties["D"].MinItems = jsonschema.Ptr(10) + s.Properties["E"].Types[0] = "mutated" + + s2, err := jsonschema.For[T]() + if err != nil { + t.Fatalf("For: %v", err) + } + if s2.Properties["A"].Type == "mutated" { + t.Fatalf("ForWithMutation: expected A.Type to not be mutated") + } + if s2.Properties["B"].AdditionalProperties.Type == "mutated" { + t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated") + } + if s2.Properties["C"].Items.Type == "mutated" { + t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated") + } + if *s2.Properties["D"].MaxItems == 10 { + t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated") + } + if *s2.Properties["D"].MinItems == 10 { + t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated") + } + if s2.Properties["E"].Types[0] == "mutated" { + t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated") + } + if s2.Required[0] == "mutated" { + t.Fatalf("ForWithMutation: expected Required[0] to not be mutated") + } +} + +type x struct { + Y y +} +type y struct { + X []x +} + +func TestForWithCycle(t *testing.T) { + type a []*a + type b1 struct{ b *b1 } // unexported field should be skipped + type b2 struct{ B *b2 } + type c1 struct{ c map[string]*c1 } // unexported field should be skipped + type c2 struct{ C map[string]*c2 } + + tests := []struct { + name string + shouldErr bool + fn func() error + }{ + {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }}, + {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }}, + {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }}, + {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }}, + {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }}, + {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }}, + {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }}, + } + + for _, test := range tests { + test := test // prevent loop shadowing + t.Run(test.name, func(t *testing.T) { + err := test.fn() + if test.shouldErr && err == nil { + t.Errorf("expected cycle error, got nil") + } + if !test.shouldErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} From 221febd372b4d2ff45c7190e9b28af2db0f81f89 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Thu, 10 Jul 2025 09:06:13 -0400 Subject: [PATCH 011/221] examples/rate-limiting: add example for session based rate limiting (#87) Added an example for session based rate limiting after session.ID() was exposed. Fixes #22 --- examples/rate-limiting/go.mod | 2 +- examples/rate-limiting/go.sum | 4 ++-- examples/rate-limiting/main.go | 42 ++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/examples/rate-limiting/go.mod b/examples/rate-limiting/go.mod index f3cf7aa1..5b76b16e 100644 --- a/examples/rate-limiting/go.mod +++ b/examples/rate-limiting/go.mod @@ -5,6 +5,6 @@ go 1.23.0 toolchain go1.24.4 require ( - github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 + github.com/modelcontextprotocol/go-sdk v0.1.0 golang.org/x/time v0.12.0 ) diff --git a/examples/rate-limiting/go.sum b/examples/rate-limiting/go.sum index c7027682..d73f0a54 100644 --- a/examples/rate-limiting/go.sum +++ b/examples/rate-limiting/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 h1:kUGBYP25FTv3ZRBhLT4iQvtx4FDl7hPkWe3isYrMxyo= -github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= +github.com/modelcontextprotocol/go-sdk v0.1.0 h1:ItzbFWYNt4EHcUrScX7P8JPASn1FVYb29G773Xkl+IU= +github.com/modelcontextprotocol/go-sdk v0.1.0/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/examples/rate-limiting/main.go b/examples/rate-limiting/main.go index 7e91b79f..c3265c4c 100644 --- a/examples/rate-limiting/main.go +++ b/examples/rate-limiting/main.go @@ -7,6 +7,8 @@ package main import ( "context" "errors" + "log" + "sync" "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -43,6 +45,43 @@ func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Lim } } +// PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting +// on a per-session basis for receiving requests. +func PerSessionRateLimiterMiddleware[S mcp.Session](limit rate.Limit, burst int) mcp.Middleware[S] { + // A map to store limiters, keyed by the session ID. + var ( + sessionLimiters = make(map[string]*rate.Limiter) + mu sync.Mutex + ) + + return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { + return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { + // It's possible that session.ID() may be empty at this point in time + // for some transports (e.g., stdio) or until the MCP initialize handshake + // has completed. + sessionID := session.ID() + if sessionID == "" { + // In this situation, you could apply a single global identifier + // if session ID is empty or bypass the rate limiter. + // In this example, we bypass the rate limiter. + log.Printf("Warning: Session ID is empty for method %q. Skipping per-session rate limiting.", method) + return next(ctx, session, method, params) // Skip limiting if ID is unavailable + } + mu.Lock() + limiter, ok := sessionLimiters[sessionID] + if !ok { + limiter = rate.NewLimiter(limit, burst) + sessionLimiters[sessionID] = limiter + } + mu.Unlock() + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + return next(ctx, session, method, params) + } + } +} + func main() { server := mcp.NewServer("greeter1", "v0.0.1", nil) server.AddReceivingMiddleware(GlobalRateLimiterMiddleware[*mcp.ServerSession](rate.NewLimiter(rate.Every(time.Second/5), 10))) @@ -50,5 +89,8 @@ func main() { "callTool": rate.NewLimiter(rate.Every(time.Second), 5), // once a second with a burst up to 5 "listTools": rate.NewLimiter(rate.Every(time.Minute), 20), // once a minute with a burst up to 20 })) + server.AddReceivingMiddleware(PerSessionRateLimiterMiddleware[*mcp.ServerSession](rate.Every(time.Second/5), 10)) // Run Server logic. + log.Println("MCP Server instance created with Middleware (but not running).") + log.Println("This example demonstrates configuration, not live interaction.") } From a1a3510315618c783933767c5ecd8267e755c875 Mon Sep 17 00:00:00 2001 From: Chris Hoff Date: Thu, 10 Jul 2025 10:07:25 -0400 Subject: [PATCH 012/221] mcp/server: make Run return on context cancel (#111) Make `Server.Run` return when the provided context is canceled. Add tests for when the `Run` context is cancelled and also for when the server process receives a signal. Fixes #107. --- mcp/cmd_test.go | 118 +++++++++++++++++++++++++++++++++++++++++++----- mcp/server.go | 17 ++++++- 2 files changed, 121 insertions(+), 14 deletions(-) diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 496694a5..bfae0c60 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -6,11 +6,13 @@ package mcp_test import ( "context" + "errors" "log" "os" "os/exec" "runtime" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -37,36 +39,103 @@ func runServer() { } } -func TestCmdTransport(t *testing.T) { - // Conservatively, limit to major OS where we know that os.Exec is - // supported. - switch runtime.GOOS { - case "darwin", "linux", "windows": - default: - t.Skip("unsupported OS") +func TestServerRunContextCancel(t *testing.T) { + server := mcp.NewServer("greeter", "v0.0.1", nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + + // run the server and capture the exit error + onServerExit := make(chan error) + go func() { + onServerExit <- server.Run(ctx, serverTransport) + }() + + // send a ping to the server to ensure it's running + client := mcp.NewClient("client", "v0.0.1", nil) + session, err := client.Connect(ctx, clientTransport) + if err != nil { + t.Fatal(err) + } + if err := session.Ping(context.Background(), nil); err != nil { + t.Fatal(err) } + // cancel the context to stop the server + cancel() + + // wait for the server to exit + // TODO: use synctest when availble + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after context cancellation") + case err := <-onServerExit: + if !errors.Is(err, context.Canceled) { + t.Fatalf("server did not exit after context cancellation, got error: %v", err) + } + } +} + +func TestServerInterrupt(t *testing.T) { + requireExec(t) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - exe, err := os.Executable() + cmd := createServerCommand(t) + + client := mcp.NewClient("client", "v0.0.1", nil) + session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { t.Fatal(err) } - cmd := exec.Command(exe) - cmd.Env = append(os.Environ(), runAsServer+"=true") + + // get a signal when the server process exits + onExit := make(chan struct{}) + go func() { + cmd.Process.Wait() + close(onExit) + }() + + // send a signal to the server process to terminate it + if runtime.GOOS == "windows" { + // Windows does not support os.Interrupt + session.Close() + } else { + cmd.Process.Signal(os.Interrupt) + } + + // wait for the server to exit + // TODO: use synctest when availble + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after SIGTERM") + case <-onExit: + } +} + +func TestCmdTransport(t *testing.T) { + requireExec(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cmd := createServerCommand(t) client := mcp.NewClient("client", "v0.0.1", nil) session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { - log.Fatal(err) + t.Fatal(err) } got, err := session.CallTool(ctx, &mcp.CallToolParams{ Name: "greet", Arguments: map[string]any{"name": "user"}, }) if err != nil { - log.Fatal(err) + t.Fatal(err) } want := &mcp.CallToolResult{ Content: []mcp.Content{ @@ -80,3 +149,28 @@ func TestCmdTransport(t *testing.T) { t.Fatalf("closing server: %v", err) } } + +func createServerCommand(t *testing.T) *exec.Cmd { + t.Helper() + + exe, err := os.Executable() + if err != nil { + t.Fatal(err) + } + cmd := exec.Command(exe) + cmd.Env = append(os.Environ(), runAsServer+"=true") + + return cmd +} + +func requireExec(t *testing.T) { + t.Helper() + + // Conservatively, limit to major OS where we know that os.Exec is + // supported. + switch runtime.GOOS { + case "darwin", "linux", "windows": + default: + t.Skip("unsupported OS") + } +} diff --git a/mcp/server.go b/mcp/server.go index cd8f808b..0aa054fa 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -404,13 +404,26 @@ func fileResourceHandler(dir string) ResourceHandler { // Run runs the server over the given transport, which must be persistent. // -// Run blocks until the client terminates the connection. +// Run blocks until the client terminates the connection or the provided +// context is cancelled. If the context is cancelled, Run closes the connection. func (s *Server) Run(ctx context.Context, t Transport) error { ss, err := s.Connect(ctx, t) if err != nil { return err } - return ss.Wait() + + ssClosed := make(chan error) + go func() { + ssClosed <- ss.Wait() + }() + + select { + case <-ctx.Done(): + ss.Close() + return ctx.Err() + case err := <-ssClosed: + return err + } } // bind implements the binder[*ServerSession] interface, so that Servers can From c037ba51d27e5ebdf81326248396024d3d05d30d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 10 Jul 2025 11:06:55 -0400 Subject: [PATCH 013/221] mcp: negotiate protocol version (#112) Negotiate the protocol version properly, and set the header on HTTP transports. We follow the logic of the Typescript SDK. - On initialization, the client sends the latest version it can support, and accepts the version that the server returns if the client supports it. If not, the connection fails. - The server accepts the client's version unless it doesn't support it, in which case it replies with its latest version. Fixes #103. --- mcp/client.go | 21 +++++++++++++++++---- mcp/server.go | 9 +++++---- mcp/shared.go | 10 ++++++++++ mcp/streamable.go | 33 ++++++++++++++++++++++++++------- mcp/streamable_test.go | 10 +++++++++- mcp/transport.go | 6 ++++++ 6 files changed, 73 insertions(+), 16 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 3a935040..512be2cb 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "fmt" "iter" "slices" "sync" @@ -86,6 +87,15 @@ func (c *Client) disconnect(cs *ClientSession) { }) } +// TODO: Consider exporting this type and its field. +type unsupportedProtocolVersionError struct { + version string +} + +func (e unsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.version) +} + // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. // @@ -106,19 +116,22 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e } params := &InitializeParams{ + ProtocolVersion: latestProtocolVersion, ClientInfo: &implementation{Name: c.name, Version: c.version}, Capabilities: caps, - ProtocolVersion: "2025-03-26", } - // TODO(rfindley): handle protocol negotiation gracefully. If the server - // responds with 2024-11-05, surface that failure to the caller of connect, - // so that they can choose a different transport. res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params) if err != nil { _ = cs.Close() return nil, err } + if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { + return nil, unsupportedProtocolVersionError{res.ProtocolVersion} + } cs.initializeResult = res + if hc, ok := cs.mcpConn.(httpConnection); ok { + hc.setProtocolVersion(res.ProtocolVersion) + } if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() return nil, err diff --git a/mcp/server.go b/mcp/server.go index 0aa054fa..de14ca06 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -647,10 +647,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam ss.mu.Unlock() }() - version := "2025-03-26" // preferred version - switch v := params.ProtocolVersion; v { - case "2024-11-05", "2025-03-26": - version = v + // If we support the client's version, reply with it. Otherwise, reply with our + // latest version. + version := params.ProtocolVersion + if !slices.Contains(supportedProtocolVersions, params.ProtocolVersion) { + version = latestProtocolVersion } return &InitializeResult{ diff --git a/mcp/shared.go b/mcp/shared.go index db871ca8..8a38777e 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -22,6 +22,16 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) +// latestProtocolVersion is the latest protocol version that this version of the SDK supports. +// It is the version that the client sends in the initialization request. +const latestProtocolVersion = "2025-06-18" + +var supportedProtocolVersions = []string{ + latestProtocolVersion, + "2025-03-26", + "2024-11-05", +} + // A MethodHandler handles MCP messages. // For methods, exactly one of the return values must be nil. // For notifications, both must be nil. diff --git a/mcp/streamable.go b/mcp/streamable.go index db3add85..52208948 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -18,6 +18,11 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" +) + // A StreamableHTTPHandler is an http.Handler that serves streamable MCP // sessions, as defined by the [MCP spec]. // @@ -88,7 +93,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } var session *StreamableServerTransport - if id := req.Header.Get("Mcp-Session-Id"); id != "" { + if id := req.Header.Get(sessionIDHeader); id != "" { h.sessionsMu.Lock() session, _ = h.sessions[id] h.sessionsMu.Unlock() @@ -386,7 +391,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h t.mu.Unlock() } - w.Header().Set("Mcp-Session-Id", t.id) + w.Header().Set(sessionIDHeader, t.id) w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") @@ -636,12 +641,19 @@ type streamableClientConn struct { closeOnce sync.Once closeErr error - mu sync.Mutex - _sessionID string + mu sync.Mutex + protocolVersion string + _sessionID string // bodies map[*http.Response]io.Closer err error } +func (c *streamableClientConn) setProtocolVersion(s string) { + c.mu.Lock() + defer c.mu.Unlock() + c.protocolVersion = s +} + func (c *streamableClientConn) SessionID() string { c.mu.Lock() defer c.mu.Unlock() @@ -707,8 +719,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string if err != nil { return "", err } + if s.protocolVersion != "" { + req.Header.Set(protocolVersionHeader, s.protocolVersion) + } if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(sessionIDHeader, sessionID) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") @@ -724,7 +739,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return "", fmt.Errorf("broken session: %v", resp.Status) } - sessionID = resp.Header.Get("Mcp-Session-Id") + sessionID = resp.Header.Get(sessionIDHeader) if resp.Header.Get("Content-Type") == "text/event-stream" { go s.handleSSE(resp) } else { @@ -763,7 +778,11 @@ func (s *streamableClientConn) Close() error { if err != nil { s.closeErr = err } else { - req.Header.Set("Mcp-Session-Id", s._sessionID) + // TODO(jba): confirm that we don't need a lock here, or add locking. + if s.protocolVersion != "" { + req.Header.Set(protocolVersionHeader, s.protocolVersion) + } + req.Header.Set(sessionIDHeader, s._sessionID) if _, err := s.client.Do(req); err != nil { s.closeErr = err } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8925b3da..3329caea 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -36,7 +36,9 @@ func TestStreamableTransports(t *testing.T) { // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + var header http.Header httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header = r.Header cookie, err := r.Cookie("test-cookie") if err != nil { t.Errorf("missing cookie: %v", err) @@ -72,6 +74,9 @@ func TestStreamableTransports(t *testing.T) { if sid == "" { t.Error("empty session ID") } + if g, w := session.mcpConn.(*streamableClientConn).protocolVersion, latestProtocolVersion; g != w { + t.Fatalf("got protocol version %q, want %q", g, w) + } // 4. The client calls the "greet" tool. params := &CallToolParams{ Name: "greet", @@ -84,6 +89,9 @@ func TestStreamableTransports(t *testing.T) { if g := session.ID(); g != sid { t.Errorf("session ID: got %q, want %q", g, sid) } + if g, w := header.Get(protocolVersionHeader), latestProtocolVersion; g != w { + t.Errorf("got protocol version header %q, want %q", g, w) + } // 5. Verify that the correct response is received. want := &CallToolResult{ @@ -154,7 +162,7 @@ func TestStreamableServerTransport(t *testing.T) { Resources: &resourceCapabilities{ListChanged: true}, Tools: &toolCapabilities{ListChanged: true}, }, - ProtocolVersion: "2025-03-26", + ProtocolVersion: latestProtocolVersion, ServerInfo: &implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) initializedMsg := req(0, "initialized", &InitializedParams{}) diff --git a/mcp/transport.go b/mcp/transport.go index 85bfaf65..f0b81650 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -53,6 +53,12 @@ type Connection interface { SessionID() string } +// An httpConnection is a [Connection] that runs over HTTP. +type httpConnection interface { + Connection + setProtocolVersion(string) +} + // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. type StdioTransport struct { From 2c577e564b3d48bb634bed0c89cd1052ce1e9051 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 10 Jul 2025 13:51:23 -0400 Subject: [PATCH 014/221] jsonschema: support "jsonschema" struct tags (#101) Struct fields may have the "jsonschema" struct tag, which is used as the description of the property. Fixes #47. --- jsonschema/infer.go | 22 +++++++++- jsonschema/infer_test.go | 87 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 10 deletions(-) diff --git a/jsonschema/infer.go b/jsonschema/infer.go index d1c1a5fb..ae441291 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -9,6 +9,7 @@ package jsonschema import ( "fmt" "reflect" + "regexp" "github.com/modelcontextprotocol/go-sdk/internal/util" ) @@ -36,8 +37,12 @@ import ( // - complex numbers // - unsafe pointers // -// The types must not have cycles. // It will return an error if there is a cycle in the types. +// +// For recognizes struct field tags named "jsonschema". +// A jsonschema tag on a field is used as the description for the corresponding property. +// For future compatibility, descriptions must not start with "WORD=", where WORD is a +// sequence of non-whitespace characters. func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. seen := make(map[reflect.Type]bool) @@ -126,10 +131,20 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[info.Name], err = forType(field.Type, seen) + fs, err := forType(field.Type, seen) if err != nil { return nil, err } + if tag, ok := field.Tag.Lookup("jsonschema"); ok { + if tag == "" { + return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) + } + if disallowedPrefixRegexp.MatchString(tag) { + return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag) + } + fs.Description = tag + } + s.Properties[info.Name] = fs if !info.Settings["omitempty"] && !info.Settings["omitzero"] { s.Required = append(s.Required, info.Name) } @@ -144,3 +159,6 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { } return s, nil } + +// Disallow jsonschema tag values beginning "WORD=", for future expansion. +var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 0b1b769a..106e5375 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -5,6 +5,7 @@ package jsonschema_test import ( + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -20,8 +21,13 @@ func forType[T any]() *jsonschema.Schema { return s } -func TestForType(t *testing.T) { +func TestFor(t *testing.T) { type schema = jsonschema.Schema + + type S struct { + B int `jsonschema:"bdesc"` + } + tests := []struct { name string got *jsonschema.Schema @@ -44,9 +50,9 @@ func TestForType(t *testing.T) { { "struct", forType[struct { - F int `json:"f"` + F int `json:"f" jsonschema:"fdesc"` G []float64 - P *bool + P *bool `jsonschema:"pdesc"` Skip string `json:"-"` NoSkip string `json:",omitempty"` unexported float64 @@ -55,13 +61,13 @@ func TestForType(t *testing.T) { &schema{ Type: "object", Properties: map[string]*schema{ - "f": {Type: "integer"}, + "f": {Type: "integer", Description: "fdesc"}, "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Types: []string{"null", "boolean"}}, + "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, "NoSkip": {Type: "string"}, }, Required: []string{"f", "G", "P"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, + AdditionalProperties: falseSchema(), }, }, { @@ -74,7 +80,37 @@ func TestForType(t *testing.T) { "Y": {Type: "integer"}, }, Required: []string{"X", "Y"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, + AdditionalProperties: falseSchema(), + }, + }, + { + "nested and embedded", + forType[struct { + A S + S + }](), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "A": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), + }, + "S": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), + }, + }, + Required: []string{"A", "S"}, + AdditionalProperties: falseSchema(), }, }, } @@ -92,6 +128,38 @@ func TestForType(t *testing.T) { } } +func forErr[T any]() error { + _, err := jsonschema.For[T]() + return err +} + +func TestForErrors(t *testing.T) { + type ( + s1 struct { + Empty int `jsonschema:""` + } + s2 struct { + Bad int `jsonschema:"$foo=1,bar"` + } + ) + + for _, tt := range []struct { + got error + want string + }{ + {forErr[map[int]int](), "unsupported map key type"}, + {forErr[s1](), "empty jsonschema tag"}, + {forErr[s2](), "must not begin with"}, + {forErr[func()](), "unsupported"}, + } { + if tt.got == nil { + t.Errorf("got nil, want error containing %q", tt.want) + } else if !strings.Contains(tt.got.Error(), tt.want) { + t.Errorf("got %q\nwant it to contain %q", tt.got, tt.want) + } + } +} + func TestForWithMutation(t *testing.T) { // This test ensures that the cached schema is not mutated when the caller // mutates the returned schema. @@ -172,7 +240,6 @@ func TestForWithCycle(t *testing.T) { } for _, test := range tests { - test := test // prevent loop shadowing t.Run(test.name, func(t *testing.T) { err := test.fn() if test.shouldErr && err == nil { @@ -184,3 +251,7 @@ func TestForWithCycle(t *testing.T) { }) } } + +func falseSchema() *jsonschema.Schema { + return &jsonschema.Schema{Not: &jsonschema.Schema{}} +} From bfa5e30cd12518da8a6aea7e3760d4b82d37964c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 10 Jul 2025 14:33:23 -0400 Subject: [PATCH 015/221] jsonrpc: package with JSON-RPC symbols (#117) Move the aliases to the internal/jsonrpc2 package into their own package. Fixes #115. --- jsonrpc/jsonrpc.go | 20 ++++++++++++++ mcp/client.go | 3 +- mcp/server.go | 3 +- mcp/shared.go | 3 +- mcp/sse.go | 13 +++++---- mcp/streamable.go | 37 +++++++++++++------------ mcp/streamable_test.go | 63 +++++++++++++++++++++--------------------- mcp/transport.go | 54 +++++++++++++++--------------------- mcp/transport_test.go | 11 ++++---- 9 files changed, 112 insertions(+), 95 deletions(-) create mode 100644 jsonrpc/jsonrpc.go diff --git a/jsonrpc/jsonrpc.go b/jsonrpc/jsonrpc.go new file mode 100644 index 00000000..f175e597 --- /dev/null +++ b/jsonrpc/jsonrpc.go @@ -0,0 +1,20 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc exposes part of a JSON-RPC v2 implementation +// for use by mcp transport authors. +package jsonrpc + +import "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + +type ( + // ID is a JSON-RPC request ID. + ID = jsonrpc2.ID + // Message is a JSON-RPC message. + Message = jsonrpc2.Message + // Request is a JSON-RPC request. + Request = jsonrpc2.Request + // Response is a JSON-RPC response. + Response = jsonrpc2.Response +) diff --git a/mcp/client.go b/mcp/client.go index 512be2cb..40d3c792 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // A Client is an MCP client, which may be connected to an MCP server @@ -301,7 +302,7 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (cs *ClientSession) handle(ctx context.Context, req *JSONRPCRequest) (any, error) { +func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { return handleReceive(ctx, cs, req) } diff --git a/mcp/server.go b/mcp/server.go index de14ca06..f9b76539 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -21,6 +21,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) const DefaultPageSize = 1000 @@ -610,7 +611,7 @@ func (ss *ServerSession) receivingMethodHandler() methodHandler { func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. -func (ss *ServerSession) handle(ctx context.Context, req *JSONRPCRequest) (any, error) { +func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() initialized := ss.initialized ss.mu.Unlock() diff --git a/mcp/shared.go b/mcp/shared.go index 8a38777e..fef20946 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -20,6 +20,7 @@ import ( "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // latestProtocolVersion is the latest protocol version that this version of the SDK supports. @@ -121,7 +122,7 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, me return info.handleMethod.(MethodHandler[S])(ctx, session, method, params) } -func handleReceive[S Session](ctx context.Context, session S, req *JSONRPCRequest) (Result, error) { +func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Request) (Result, error) { info, ok := session.receivingMethodInfos()[req.Method] if !ok { return nil, jsonrpc2.ErrNotHandled diff --git a/mcp/sse.go b/mcp/sse.go index d1b52599..f0d7b34c 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // This file implements support for SSE (HTTP with server-sent events) @@ -111,7 +112,7 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { // - Close terminates the hanging GET. type SSEServerTransport struct { endpoint string - incoming chan JSONRPCMessage // queue of incoming messages; never closed + incoming chan jsonrpc.Message // queue of incoming messages; never closed // We must guard both pushes to the incoming queue and writes to the response // writer, because incoming POST requests are arbitrarily concurrent and we @@ -138,7 +139,7 @@ func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTra return &SSEServerTransport{ endpoint: endpoint, w: w, - incoming: make(chan JSONRPCMessage, 100), + incoming: make(chan jsonrpc.Message, 100), done: make(chan struct{}), } } @@ -267,7 +268,7 @@ type sseServerConn struct { func (s sseServerConn) SessionID() string { return "" } // Read implements jsonrpc2.Reader. -func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -279,7 +280,7 @@ func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { } // Write implements jsonrpc2.Writer. -func (s sseServerConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { if ctx.Err() != nil { return ctx.Err() } @@ -532,7 +533,7 @@ func (c *sseClientConn) isDone() bool { return c.closed } -func (c *sseClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -553,7 +554,7 @@ func (c *sseClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { } } -func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return err diff --git a/mcp/streamable.go b/mcp/streamable.go index 52208948..11d70a38 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) const ( @@ -157,12 +158,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { return &StreamableServerTransport{ id: sessionID, - incoming: make(chan JSONRPCMessage, 10), + incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), outgoingMessages: make(map[streamID][]*streamableMsg), signals: make(map[streamID]chan struct{}), - requestStreams: make(map[JSONRPCID]streamID), - streamRequests: make(map[streamID]map[JSONRPCID]struct{}), + requestStreams: make(map[jsonrpc.ID]streamID), + streamRequests: make(map[streamID]map[jsonrpc.ID]struct{}), } } @@ -176,7 +177,7 @@ type StreamableServerTransport struct { nextStreamID atomic.Int64 // incrementing next stream ID id string - incoming chan JSONRPCMessage // messages from the client to the server + incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex @@ -226,7 +227,7 @@ type StreamableServerTransport struct { // Lifecycle: requestStreams persists for the duration of the session. // // TODO(rfindley): clean up once requests are handled. - requestStreams map[JSONRPCID]streamID + requestStreams map[jsonrpc.ID]streamID // streamRequests tracks the set of unanswered incoming RPCs for each logical // stream. @@ -237,7 +238,7 @@ type StreamableServerTransport struct { // Lifecycle: streamRequests values persist as until the requests have been // replied to by the server. Notably, NOT until they are sent to an HTTP // response, as delivery is not guaranteed. - streamRequests map[streamID]map[JSONRPCID]struct{} + streamRequests map[streamID]map[jsonrpc.ID]struct{} } type streamID int64 @@ -271,7 +272,7 @@ func (s *StreamableServerTransport) Connect(context.Context) (Connection, error) // 2. Expose a 'HandlerTransport' interface that allows transports to provide // a handler middleware, so that we don't hard-code this behavior in // ServerSession.handle. -// 3. Add a `func ForRequest(context.Context) JSONRPCID` accessor that lets +// 3. Add a `func ForRequest(context.Context) jsonrpc.ID` accessor that lets // any transport access the incoming request ID. // // For now, by giving only the StreamableServerTransport access to the request @@ -340,9 +341,9 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } - requests := make(map[JSONRPCID]struct{}) + requests := make(map[jsonrpc.ID]struct{}) for _, msg := range incoming { - if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { requests[req.ID] = struct{}{} } } @@ -352,7 +353,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R signal := make(chan struct{}, 1) t.mu.Lock() if len(requests) > 0 { - t.streamRequests[id] = make(map[JSONRPCID]struct{}) + t.streamRequests[id] = make(map[jsonrpc.ID]struct{}) } for reqID := range requests { t.requestStreams[reqID] = id @@ -484,7 +485,7 @@ func parseEventID(eventID string) (sid streamID, idx int, ok bool) { } // Read implements the [Connection] interface. -func (t *StreamableServerTransport) Read(ctx context.Context) (JSONRPCMessage, error) { +func (t *StreamableServerTransport) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -499,10 +500,10 @@ func (t *StreamableServerTransport) Read(ctx context.Context) (JSONRPCMessage, e } // Write implements the [Connection] interface. -func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessage) error { +func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Message) error { // Find the incoming request that this write relates to, if any. - var forRequest, replyTo JSONRPCID - if resp, ok := msg.(*JSONRPCResponse); ok { + var forRequest, replyTo jsonrpc.ID + if resp, ok := msg.(*jsonrpc.Response); ok { // If the message is a response, it relates to its request (of course). forRequest = resp.ID replyTo = resp.ID @@ -511,7 +512,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessag // ongoing request. This may not be the case if the request way made with // an unrelated context. if v := ctx.Value(idContextKey{}); v != nil { - forRequest = v.(JSONRPCID) + forRequest = v.(jsonrpc.ID) } } @@ -661,7 +662,7 @@ func (c *streamableClientConn) SessionID() string { } // Read implements the [Connection] interface. -func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -673,7 +674,7 @@ func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) } // Write implements the [Connection] interface. -func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { s.mu.Lock() if s.err != nil { s.mu.Unlock() @@ -709,7 +710,7 @@ func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) er return nil } -func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg JSONRPCMessage) (string, error) { +func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return "", err diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3329caea..412d2e1d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestStreamableTransports(t *testing.T) { @@ -126,16 +127,16 @@ func TestStreamableServerTransport(t *testing.T) { // Redundant with OnRequest: all OnRequest steps are asynchronous. Async bool - Method string // HTTP request method - Send []JSONRPCMessage // messages to send - CloseAfter int // if nonzero, close after receiving this many messages - StatusCode int // expected status code - Recv []JSONRPCMessage // expected messages to receive + Method string // HTTP request method + Send []jsonrpc.Message // messages to send + CloseAfter int // if nonzero, close after receiving this many messages + StatusCode int // expected status code + Recv []jsonrpc.Message // expected messages to receive } // JSON-RPC message constructors. - req := func(id int64, method string, params any) *JSONRPCRequest { - r := &JSONRPCRequest{ + req := func(id int64, method string, params any) *jsonrpc.Request { + r := &jsonrpc.Request{ Method: method, Params: mustMarshal(t, params), } @@ -144,8 +145,8 @@ func TestStreamableServerTransport(t *testing.T) { } return r } - resp := func(id int64, result any, err error) *JSONRPCResponse { - return &JSONRPCResponse{ + resp := func(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), Result: mustMarshal(t, result), Error: err, @@ -168,13 +169,13 @@ func TestStreamableServerTransport(t *testing.T) { initializedMsg := req(0, "initialized", &InitializedParams{}) initialize := step{ Method: "POST", - Send: []JSONRPCMessage{initReq}, + Send: []jsonrpc.Message{initReq}, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{initResp}, + Recv: []jsonrpc.Message{initResp}, } initialized := step{ Method: "POST", - Send: []JSONRPCMessage{initializedMsg}, + Send: []jsonrpc.Message{initializedMsg}, StatusCode: http.StatusAccepted, } @@ -190,9 +191,9 @@ func TestStreamableServerTransport(t *testing.T) { initialized, { Method: "POST", - Send: []JSONRPCMessage{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{resp(2, &CallToolResult{}, nil)}, + Recv: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, }, }, }, @@ -209,11 +210,11 @@ func TestStreamableServerTransport(t *testing.T) { initialized, { Method: "POST", - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, @@ -234,18 +235,18 @@ func TestStreamableServerTransport(t *testing.T) { { Method: "POST", OnRequest: 1, - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, StatusCode: http.StatusAccepted, }, { Method: "POST", - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, @@ -275,7 +276,7 @@ func TestStreamableServerTransport(t *testing.T) { { Method: "POST", OnRequest: 1, - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, StatusCode: http.StatusAccepted, @@ -285,18 +286,18 @@ func TestStreamableServerTransport(t *testing.T) { Async: true, StatusCode: http.StatusOK, CloseAfter: 2, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, }, { Method: "POST", - Send: []JSONRPCMessage{ + Send: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + Recv: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, }, @@ -315,9 +316,9 @@ func TestStreamableServerTransport(t *testing.T) { }, { Method: "POST", - Send: []JSONRPCMessage{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{resp(2, nil, &jsonrpc2.WireError{ + Recv: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, }, @@ -344,7 +345,7 @@ func TestStreamableServerTransport(t *testing.T) { httpServer := httptest.NewServer(handler) defer httpServer.Close() - // blocks records request blocks by JSONRPC ID. + // blocks records request blocks by jsonrpc. ID. // // When an OnRequest step is encountered, it waits on the corresponding // block. When a request with that ID is received, the block is closed. @@ -382,8 +383,8 @@ func TestStreamableServerTransport(t *testing.T) { // Collect messages received during this request, unblock other steps // when requests are received. - var got []JSONRPCMessage - out := make(chan JSONRPCMessage) + var got []jsonrpc.Message + out := make(chan jsonrpc.Message) // Cancel the step if we encounter a request that isn't going to be // handled. ctx, cancel := context.WithCancel(context.Background()) @@ -394,7 +395,7 @@ func TestStreamableServerTransport(t *testing.T) { defer wg.Done() for m := range out { - if req, ok := m.(*JSONRPCRequest); ok && req.ID.IsValid() { + if req, ok := m.(*jsonrpc.Request); ok && req.ID.IsValid() { // Encountered a server->client request. We should have a // response queued. Otherwise, we may deadlock. mu.Lock() @@ -427,7 +428,7 @@ func TestStreamableServerTransport(t *testing.T) { } wg.Wait() - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id JSONRPCID) any { return id.Raw() }) + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(step.Recv, got, transform); diff != "" { t.Errorf("received unexpected messages (-want +got):\n%s", diff) } @@ -469,7 +470,7 @@ func TestStreamableServerTransport(t *testing.T) { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []JSONRPCMessage, out chan<- JSONRPCMessage) (string, int, error) { +func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []jsonrpc.Message, out chan<- jsonrpc.Message) (string, int, error) { defer close(out) var body []byte diff --git a/mcp/transport.go b/mcp/transport.go index f0b81650..a7de5061 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -16,6 +16,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // ErrConnectionClosed is returned when sending a message to a connection that @@ -34,21 +35,10 @@ type Transport interface { Connect(ctx context.Context) (Connection, error) } -type ( - // JSONRPCID is a JSON-RPC request ID. - JSONRPCID = jsonrpc2.ID - // JSONRPCMessage is a JSON-RPC message. - JSONRPCMessage = jsonrpc2.Message - // JSONRPCRequest is a JSON-RPC request. - JSONRPCRequest = jsonrpc2.Request - // JSONRPCResponse is a JSON-RPC response. - JSONRPCResponse = jsonrpc2.Response -) - // A Connection is a logical bidirectional JSON-RPC connection. type Connection interface { - Read(context.Context) (JSONRPCMessage, error) - Write(context.Context, JSONRPCMessage) error + Read(context.Context) (jsonrpc.Message, error) + Write(context.Context, jsonrpc.Message) error Close() error // may be called concurrently by both peers SessionID() string } @@ -100,7 +90,7 @@ type binder[T handler] interface { } type handler interface { - handle(ctx context.Context, req *JSONRPCRequest) (any, error) + handle(ctx context.Context, req *jsonrpc.Request) (any, error) setConn(Connection) } @@ -143,7 +133,7 @@ type canceller struct { } // Preempt implements jsonrpc2.Preempter. -func (c *canceller) Preempt(ctx context.Context, req *JSONRPCRequest) (result any, err error) { +func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { if req.Method == "notifications/cancelled" { var params CancelledParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { @@ -212,7 +202,7 @@ type loggingConn struct { func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } // loggingReader is a stream middleware that logs incoming messages. -func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { msg, err := s.delegate.Read(ctx) if err != nil { fmt.Fprintf(s.w, "read error: %v", err) @@ -227,7 +217,7 @@ func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) { } // loggingWriter is a stream middleware that logs outgoing messages. -func (s *loggingConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { err := s.delegate.Write(ctx, msg) if err != nil { fmt.Fprintf(s.w, "write error: %v", err) @@ -265,7 +255,7 @@ func (r rwc) Close() error { } // An ioConn is a transport that delimits messages with newlines across -// a bidirectional stream, and supports JSONRPC2 message batching. +// a bidirectional stream, and supports jsonrpc.2 message batching. // // See https://github.com/ndjson/ndjson-spec for discussion of newline // delimited JSON. @@ -277,11 +267,11 @@ type ioConn struct { // If outgoiBatch has a positive capacity, it will be used to batch requests // and notifications before sending. - outgoingBatch []JSONRPCMessage + outgoingBatch []jsonrpc.Message // Unread messages in the last batch. Since reads are serialized, there is no // need to guard here. - queue []JSONRPCMessage + queue []jsonrpc.Message // batches correlate incoming requests to the batch in which they arrived. // Since writes may be concurrent to reads, we need to guard this with a mutex. @@ -325,7 +315,7 @@ func (t *ioConn) addBatch(batch *msgBatch) error { // The second result reports whether resp was part of a batch. If this is true, // the first result is nil if the batch is still incomplete, or the full set of // batch responses if resp completed the batch. -func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { +func (t *ioConn) updateBatch(resp *jsonrpc.Response) ([]*jsonrpc.Response, bool) { t.batchMu.Lock() defer t.batchMu.Unlock() @@ -345,9 +335,9 @@ func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { return nil, false } -// A msgBatch records information about an incoming batch of JSONRPC2 calls. +// A msgBatch records information about an incoming batch of jsonrpc.2 calls. // -// The JSONRPC2 spec (https://www.jsonrpc.org/specification#batch) says: +// The jsonrpc.2 spec (https://www.jsonrpc.org/specification#batch) says: // // "The Server should respond with an Array containing the corresponding // Response objects, after all of the batch Request objects have been @@ -360,14 +350,14 @@ func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { // When there are no unresolved calls, the response payload is sent. type msgBatch struct { unresolved map[jsonrpc2.ID]int - responses []*JSONRPCResponse + responses []*jsonrpc.Response } -func (t *ioConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { return t.read(ctx, t.in) } -func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, error) { +func (t *ioConn) read(ctx context.Context, in *json.Decoder) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -392,7 +382,7 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, er if batch { var respBatch *msgBatch // track incoming requests in the batch for _, msg := range msgs { - if req, ok := msg.(*JSONRPCRequest); ok { + if req, ok := msg.(*jsonrpc.Request); ok { if respBatch == nil { respBatch = &msgBatch{ unresolved: make(map[jsonrpc2.ID]int), @@ -417,7 +407,7 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, er // readBatch reads batch data, which may be either a single JSON-RPC message, // or an array of JSON-RPC messages. -func readBatch(data []byte) (msgs []JSONRPCMessage, isBatch bool, _ error) { +func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { // Try to read an array of messages first. var rawBatch []json.RawMessage if err := json.Unmarshal(data, &rawBatch); err == nil { @@ -435,10 +425,10 @@ func readBatch(data []byte) (msgs []JSONRPCMessage, isBatch bool, _ error) { } // Try again with a single message. msg, err := jsonrpc2.DecodeMessage(data) - return []JSONRPCMessage{msg}, false, err + return []jsonrpc.Message{msg}, false, err } -func (t *ioConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { select { case <-ctx.Done(): return ctx.Err() @@ -449,7 +439,7 @@ func (t *ioConn) Write(ctx context.Context, msg JSONRPCMessage) error { // check that first. Otherwise, it is a request or notification, and we may // want to collect it into a batch before sending, if we're configured to use // outgoing batches. - if resp, ok := msg.(*JSONRPCResponse); ok { + if resp, ok := msg.(*jsonrpc.Response); ok { if batch, ok := t.updateBatch(resp); ok { if len(batch) > 0 { data, err := marshalMessages(batch) @@ -489,7 +479,7 @@ func (t *ioConn) Close() error { return t.rwc.Close() } -func marshalMessages[T JSONRPCMessage](msgs []T) ([]byte, error) { +func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { var rawMsgs []json.RawMessage for _, msg := range msgs { raw, err := jsonrpc2.EncodeMessage(msg) diff --git a/mcp/transport_test.go b/mcp/transport_test.go index db18a352..c63b84ee 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestBatchFraming(t *testing.T) { @@ -22,10 +23,10 @@ func TestBatchFraming(t *testing.T) { r, w := io.Pipe() tport := newIOConn(rwc{r, w}) - tport.outgoingBatch = make([]JSONRPCMessage, 0, 2) + tport.outgoingBatch = make([]jsonrpc.Message, 0, 2) // Read the two messages into a channel, for easy testing later. - read := make(chan JSONRPCMessage) + read := make(chan jsonrpc.Message) go func() { for range 2 { msg, _ := tport.Read(ctx) @@ -34,7 +35,7 @@ func TestBatchFraming(t *testing.T) { }() // The first write should not yet be observed by the reader. - tport.Write(ctx, &JSONRPCRequest{ID: jsonrpc2.Int64ID(1), Method: "test"}) + tport.Write(ctx, &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "test"}) select { case got := <-read: t.Fatalf("after one write, got message %v", got) @@ -42,10 +43,10 @@ func TestBatchFraming(t *testing.T) { } // ...but the second write causes both messages to be observed. - tport.Write(ctx, &JSONRPCRequest{ID: jsonrpc2.Int64ID(2), Method: "test"}) + tport.Write(ctx, &jsonrpc.Request{ID: jsonrpc2.Int64ID(2), Method: "test"}) for _, want := range []int64{1, 2} { got := <-read - if got := got.(*JSONRPCRequest).ID.Raw(); got != want { + if got := got.(*jsonrpc.Request).ID.Raw(); got != want { t.Errorf("got message #%d, want #%d", got, want) } } From b4febf129b4d45e1ad9d19e2b5168bf88e7cbdd6 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 10 Jul 2025 15:52:11 -0400 Subject: [PATCH 016/221] mcp: advertise a capability if and only if added (#93) If tools are added to a Server when Run is called, it will send toolCapabilities to the client. Otherwise it won't. Ditto for prompts and resources. For #56. --- mcp/features.go | 3 ++ mcp/server.go | 41 +++++++++++++------- mcp/server_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++ mcp/streamable_test.go | 2 - 4 files changed, 118 insertions(+), 16 deletions(-) diff --git a/mcp/features.go b/mcp/features.go index 1777b33f..43c58854 100644 --- a/mcp/features.go +++ b/mcp/features.go @@ -66,6 +66,9 @@ func (s *featureSet[T]) get(uid string) (T, bool) { return t, ok } +// len returns the number of features in the set. +func (s *featureSet[T]) len() int { return len(s.features) } + // all returns an iterator over of all the features in the set // sorted by unique ID. func (s *featureSet[T]) all() iter.Seq[T] { diff --git a/mcp/server.go b/mcp/server.go index f9b76539..16cc8d6b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -204,6 +204,26 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { func() bool { return s.resourceTemplates.remove(uriTemplates...) }) } +func (s *Server) capabilities() *serverCapabilities { + s.mu.Lock() + defer s.mu.Unlock() + + caps := &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + } + if s.tools.len() > 0 { + caps.Tools = &toolCapabilities{ListChanged: true} + } + if s.prompts.len() > 0 { + caps.Prompts = &promptCapabilities{ListChanged: true} + } + if s.resources.len() > 0 || s.resourceTemplates.len() > 0 { + caps.Resources = &resourceCapabilities{ListChanged: true} + } + return caps +} + func (s *Server) complete(ctx context.Context, ss *ServerSession, params *CompleteParams) (Result, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound @@ -407,6 +427,11 @@ func fileResourceHandler(dir string) ResourceHandler { // // Run blocks until the client terminates the connection or the provided // context is cancelled. If the context is cancelled, Run closes the connection. +// +// If tools have been added to the server before this call, then the server will +// advertise the capability for tools, including the ability to send list-changed notifications. +// If no tools have been added, the server will not have the tool capability. +// The same goes for other features like prompts and resources. func (s *Server) Run(ctx context.Context, t Transport) error { ss, err := s.Connect(ctx, t) if err != nil { @@ -659,20 +684,8 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam // TODO(rfindley): alter behavior when falling back to an older version: // reject unsupported features. ProtocolVersion: version, - Capabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Prompts: &promptCapabilities{ - ListChanged: true, - }, - Tools: &toolCapabilities{ - ListChanged: true, - }, - Resources: &resourceCapabilities{ - ListChanged: true, - }, - Logging: &loggingCapabilities{}, - }, - Instructions: ss.server.opts.Instructions, + Capabilities: ss.server.capabilities(), + Instructions: ss.server.opts.Instructions, ServerInfo: &implementation{ Name: ss.server.name, Version: ss.server.version, diff --git a/mcp/server_test.go b/mcp/server_test.go index 19701f39..cc94003c 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -227,3 +227,91 @@ func TestServerPaginateVariousPageSizes(t *testing.T) { } } } + +func TestServerCapabilities(t *testing.T) { + testCases := []struct { + name string + configureServer func(s *Server) + wantCapabilities *serverCapabilities + }{ + { + name: "No capabilities", + configureServer: func(s *Server) {}, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + }, + }, + { + name: "With prompts", + configureServer: func(s *Server) { + s.AddPrompt(&Prompt{Name: "p"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, + }, + }, + { + name: "With resources", + configureServer: func(s *Server) { + s.AddResource(&Resource{URI: "file:///r"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true}, + }, + }, + { + name: "With resource templates", + configureServer: func(s *Server) { + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true}, + }, + }, + { + name: "With tools", + configureServer: func(s *Server) { + s.AddTool(&Tool{Name: "t"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, + }, + }, + { + name: "With all capabilities", + configureServer: func(s *Server) { + s.AddPrompt(&Prompt{Name: "p"}, nil) + s.AddResource(&Resource{URI: "file:///r"}, nil) + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + s.AddTool(&Tool{Name: "t"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true}, + Tools: &toolCapabilities{ListChanged: true}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := NewServer("", "", nil) + tc.configureServer(server) + gotCapabilities := server.capabilities() + if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { + t.Errorf("capabilities() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 412d2e1d..da9c4285 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -159,8 +159,6 @@ func TestStreamableServerTransport(t *testing.T) { Capabilities: &serverCapabilities{ Completions: &completionCapabilities{}, Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, Tools: &toolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion, From 6c6243c0b0b624f00a8d79de14819527bb6f32b3 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 10 Jul 2025 16:57:38 -0400 Subject: [PATCH 017/221] mcp: NewClient and NewServer take Implementation (#113) Change the name and version arguments to NewClient and NewServer to a `*Implementation`, to future-proof against the spec. Fixes #109. --- README.md | 4 ++-- design/design.md | 8 ++++---- examples/completion/main.go | 2 +- examples/hello/main.go | 2 +- examples/memory/main.go | 2 +- examples/sse/main.go | 4 ++-- internal/readme/client/client.go | 2 +- internal/readme/server/server.go | 2 +- mcp/client.go | 17 ++++++++++------- mcp/cmd_test.go | 12 +++++++----- mcp/example_progress_test.go | 2 +- mcp/mcp_test.go | 30 ++++++++++++++++-------------- mcp/protocol.go | 8 ++++---- mcp/server.go | 24 ++++++++++++------------ mcp/server_example_test.go | 8 ++++---- mcp/server_test.go | 2 +- mcp/sse_example_test.go | 4 ++-- mcp/sse_test.go | 4 ++-- mcp/streamable_test.go | 8 ++++---- 19 files changed, 76 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index ce4d9f11..be91e751 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ func main() { ctx := context.Background() // Create a new client, with no features. - client := mcp.NewClient("mcp-client", "v1.0.0", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) @@ -125,7 +125,7 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam func main() { // Create a server with a single tool. - server := mcp.NewServer("greeter", "v1.0.0", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects diff --git a/design/design.md b/design/design.md index 610de399..56bef53e 100644 --- a/design/design.md +++ b/design/design.md @@ -323,7 +323,7 @@ Sessions are created from either `Client` or `Server` using the `Connect` method ```go type Client struct { /* ... */ } -func NewClient(name, version string, opts *ClientOptions) *Client +func NewClient(impl *Implementation, opts *ClientOptions) *Client func (*Client) Connect(context.Context, Transport) (*ClientSession, error) func (*Client) Sessions() iter.Seq[*ClientSession] // Methods for adding/removing client features are described below. @@ -338,7 +338,7 @@ func (*ClientSession) Wait() error // For example: ClientSession.ListTools. type Server struct { /* ... */ } -func NewServer(name, version string, opts *ServerOptions) *Server +func NewServer(impl *Implementation, opts *ServerOptions) *Server func (*Server) Connect(context.Context, Transport) (*ServerSession, error) func (*Server) Sessions() iter.Seq[*ServerSession] // Methods for adding/removing server features are described below. @@ -356,7 +356,7 @@ func (*ServerSession) Wait() error Here's an example of these APIs from the client side: ```go -client := mcp.NewClient("mcp-client", "v1.0.0", nil) +client := mcp.NewClient(&mcp.Implementation{Name:"mcp-client", Version:"v1.0.0"}, nil) // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) session, err := client.Connect(ctx, transport) @@ -371,7 +371,7 @@ A server that can handle that client call would look like this: ```go // Create a server with a single tool. -server := mcp.NewServer("greeter", "v1.0.0", nil) +server := mcp.NewServer(&mcp.Implementation{Name:"greeter", Version:"v1.0.0"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects. if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { diff --git a/examples/completion/main.go b/examples/completion/main.go index a24299bc..d6530b5d 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -40,7 +40,7 @@ func main() { // Create the MCP Server instance and assign the handler. // No server running, just showing the configuration. - _ = mcp.NewServer("myServer", "v1.0.0", &mcp.ServerOptions{ + _ = mcp.NewServer(&mcp.Implementation{Name: "server"}, &mcp.ServerOptions{ CompletionHandler: myCompletionHandler, }) diff --git a/examples/hello/main.go b/examples/hello/main.go index 4db20cc8..84cff1e0 100644 --- a/examples/hello/main.go +++ b/examples/hello/main.go @@ -42,7 +42,7 @@ func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptP func main() { flag.Parse() - server := mcp.NewServer("greeter", "v0.0.1", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) server.AddPrompt(&mcp.Prompt{Name: "greet"}, PromptHi) server.AddResource(&mcp.Resource{ diff --git a/examples/memory/main.go b/examples/memory/main.go index d3d78110..61ab1060 100644 --- a/examples/memory/main.go +++ b/examples/memory/main.go @@ -91,7 +91,7 @@ func main() { kb := knowledgeBase{s: kbStore} // Setup MCP server with knowledge base tools - server := mcp.NewServer("memory", "v0.0.1", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "memory"}, nil) mcp.AddTool(server, &mcp.Tool{ Name: "create_entities", Description: "Create multiple new entities in the knowledge graph", diff --git a/examples/sse/main.go b/examples/sse/main.go index c93320ab..99b83e65 100644 --- a/examples/sse/main.go +++ b/examples/sse/main.go @@ -34,10 +34,10 @@ func main() { log.Fatal("http address not set") } - server1 := mcp.NewServer("greeter1", "v0.0.1", nil) + server1 := mcp.NewServer(&mcp.Implementation{Name: "greeter1"}, nil) mcp.AddTool(server1, &mcp.Tool{Name: "greet1", Description: "say hi"}, SayHi) - server2 := mcp.NewServer("greeter2", "v0.0.1", nil) + server2 := mcp.NewServer(&mcp.Implementation{Name: "greeter2"}, nil) mcp.AddTool(server2, &mcp.Tool{Name: "greet2", Description: "say hello"}, SayHi) log.Printf("MCP servers serving at %s", *httpAddr) diff --git a/internal/readme/client/client.go b/internal/readme/client/client.go index 44bc515c..666ee925 100644 --- a/internal/readme/client/client.go +++ b/internal/readme/client/client.go @@ -17,7 +17,7 @@ func main() { ctx := context.Background() // Create a new client, with no features. - client := mcp.NewClient("mcp-client", "v1.0.0", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 1fe211ea..7c025412 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -24,7 +24,7 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam func main() { // Create a server with a single tool. - server := mcp.NewServer("greeter", "v1.0.0", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects diff --git a/mcp/client.go b/mcp/client.go index 40d3c792..b48ad7a1 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -19,8 +19,7 @@ import ( // A Client is an MCP client, which may be connected to an MCP server // using the [Client.Connect] method. type Client struct { - name string - version string + impl *Implementation opts ClientOptions mu sync.Mutex roots *featureSet[*Root] @@ -29,15 +28,19 @@ type Client struct { receivingMethodHandler_ MethodHandler[*ClientSession] } -// NewClient creates a new Client. +// NewClient creates a new [Client]. // // Use [Client.Connect] to connect it to an MCP server. // +// The first argument must not be nil. +// // If non-nil, the provided options configure the Client. -func NewClient(name, version string, opts *ClientOptions) *Client { +func NewClient(impl *Implementation, opts *ClientOptions) *Client { + if impl == nil { + panic("nil Implementation") + } c := &Client{ - name: name, - version: version, + impl: impl, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], @@ -118,7 +121,7 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e params := &InitializeParams{ ProtocolVersion: latestProtocolVersion, - ClientInfo: &implementation{Name: c.name, Version: c.version}, + ClientInfo: c.impl, Capabilities: caps, } res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params) diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index bfae0c60..bc149f4c 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -32,7 +32,7 @@ func TestMain(m *testing.M) { func runServer() { ctx := context.Background() - server := mcp.NewServer("greeter", "v0.0.1", nil) + server := mcp.NewServer(testImpl, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { log.Fatal(err) @@ -40,7 +40,7 @@ func runServer() { } func TestServerRunContextCancel(t *testing.T) { - server := mcp.NewServer("greeter", "v0.0.1", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) ctx, cancel := context.WithCancel(context.Background()) @@ -55,7 +55,7 @@ func TestServerRunContextCancel(t *testing.T) { }() // send a ping to the server to ensure it's running - client := mcp.NewClient("client", "v0.0.1", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) session, err := client.Connect(ctx, clientTransport) if err != nil { t.Fatal(err) @@ -87,7 +87,7 @@ func TestServerInterrupt(t *testing.T) { cmd := createServerCommand(t) - client := mcp.NewClient("client", "v0.0.1", nil) + client := mcp.NewClient(testImpl, nil) session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { t.Fatal(err) @@ -125,7 +125,7 @@ func TestCmdTransport(t *testing.T) { cmd := createServerCommand(t) - client := mcp.NewClient("client", "v0.0.1", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { t.Fatal(err) @@ -174,3 +174,5 @@ func requireExec(t *testing.T) { t.Skip("unsupported OS") } } + +var testImpl = &mcp.Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/example_progress_test.go b/mcp/example_progress_test.go index 902b2347..6c771e20 100644 --- a/mcp/example_progress_test.go +++ b/mcp/example_progress_test.go @@ -16,7 +16,7 @@ var nextProgressToken atomic.Int64 // This middleware function adds a progress token to every outgoing request // from the client. func Example_progressMiddleware() { - c := mcp.NewClient("test", "v1", nil) + c := mcp.NewClient(testImpl, nil) c.AddSendingMiddleware(addProgressToken[*mcp.ClientSession]) _ = c } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 70c79b58..1952e8d0 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -64,7 +64,7 @@ func TestEndToEnd(t *testing.T) { notificationChans["progress_server"] <- 0 }, } - s := NewServer("testServer", "v1.0.0", sopts) + s := NewServer(testImpl, sopts) AddTool(s, &Tool{ Name: "greet", Description: "say hi", @@ -125,7 +125,7 @@ func TestEndToEnd(t *testing.T) { notificationChans["progress_client"] <- 0 }, } - c := NewClient("testClient", "v1.0.0", opts) + c := NewClient(testImpl, opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) if err != nil { t.Fatal(err) @@ -510,7 +510,7 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) + s := NewServer(testImpl, nil) if config != nil { config(s) } @@ -519,7 +519,7 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) cs, err := c.Connect(ctx, ct) if err != nil { t.Fatal(err) @@ -562,13 +562,13 @@ func TestBatching(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) + s := NewServer(testImpl, nil) _, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) // TODO: this test is broken, because increasing the batch size here causes // 'initialize' to block. Therefore, we can only test with a size of 1. // Since batching is being removed, we can probably just delete this. @@ -632,7 +632,7 @@ func TestMiddleware(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) + s := NewServer(testImpl, nil) ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) @@ -656,7 +656,7 @@ func TestMiddleware(t *testing.T) { s.AddSendingMiddleware(traceCalls[*ServerSession](&sbuf, "S1"), traceCalls[*ServerSession](&sbuf, "S2")) s.AddReceivingMiddleware(traceCalls[*ServerSession](&sbuf, "R1"), traceCalls[*ServerSession](&sbuf, "R2")) - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) @@ -741,13 +741,13 @@ func TestNoJSONNull(t *testing.T) { var logbuf safeBuffer ct = NewLoggingTransport(ct, &logbuf) - s := NewServer("testServer", "v1.0.0", nil) + s := NewServer(testImpl, nil) ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) cs, err := c.Connect(ctx, ct) if err != nil { t.Fatal(err) @@ -810,7 +810,7 @@ func TestKeepAlive(t *testing.T) { serverOpts := &ServerOptions{ KeepAlive: 100 * time.Millisecond, } - s := NewServer("testServer", "v1.0.0", serverOpts) + s := NewServer(testImpl, serverOpts) AddTool(s, greetTool(), sayHi) ss, err := s.Connect(ctx, st) @@ -822,7 +822,7 @@ func TestKeepAlive(t *testing.T) { clientOpts := &ClientOptions{ KeepAlive: 100 * time.Millisecond, } - c := NewClient("testClient", "v1.0.0", clientOpts) + c := NewClient(testImpl, clientOpts) cs, err := c.Connect(ctx, ct) if err != nil { t.Fatal(err) @@ -855,7 +855,7 @@ func TestKeepAliveFailure(t *testing.T) { ct, st := NewInMemoryTransports() // Server without keepalive (to test one-sided keepalive) - s := NewServer("testServer", "v1.0.0", nil) + s := NewServer(testImpl, nil) AddTool(s, greetTool(), sayHi) ss, err := s.Connect(ctx, st) if err != nil { @@ -866,7 +866,7 @@ func TestKeepAliveFailure(t *testing.T) { clientOpts := &ClientOptions{ KeepAlive: 50 * time.Millisecond, } - c := NewClient("testClient", "v1.0.0", clientOpts) + c := NewClient(testImpl, clientOpts) cs, err := c.Connect(ctx, ct) if err != nil { t.Fatal(err) @@ -895,3 +895,5 @@ func TestKeepAliveFailure(t *testing.T) { t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err) } + +var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/protocol.go b/mcp/protocol.go index eba9e73d..4f47c961 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -292,7 +292,7 @@ type InitializeParams struct { // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Capabilities *ClientCapabilities `json:"capabilities"` - ClientInfo *implementation `json:"clientInfo"` + ClientInfo *Implementation `json:"clientInfo"` // The latest version of the Model Context Protocol that the client supports. // The client may decide to support older versions as well. ProtocolVersion string `json:"protocolVersion"` @@ -318,7 +318,7 @@ type InitializeResult struct { // may not match the version that the client requested. If the client cannot // support this version, it must disconnect. ProtocolVersion string `json:"protocolVersion"` - ServerInfo *implementation `json:"serverInfo"` + ServerInfo *Implementation `json:"serverInfo"` } type InitializedParams struct { @@ -863,9 +863,9 @@ func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) // TODO(jba): add ElicitRequest and related types. -// Describes the name and version of an MCP implementation, with an optional +// An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. -type implementation struct { +type Implementation struct { // Intended for programmatic or logical use, but used as a display name in past // specs or fallback (if title isn't present). Name string `json:"name"` diff --git a/mcp/server.go b/mcp/server.go index 16cc8d6b..060ae538 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -32,9 +32,8 @@ const DefaultPageSize = 1000 // sessions by using [Server.Start] or [Server.Run]. type Server struct { // fixed at creation - name string - version string - opts ServerOptions + impl *Implementation + opts ServerOptions mu sync.Mutex prompts *featureSet[*serverPrompt] @@ -68,13 +67,18 @@ type ServerOptions struct { } // NewServer creates a new MCP server. The resulting server has no features: -// add features using [Server.AddTools], [Server.AddPrompts] and [Server.AddResources]. +// add features using the various Server.AddXXX methods, and the [AddTool] function. // // The server can be connected to one or more MCP clients using [Server.Start] // or [Server.Run]. // -// If non-nil, the provided options is used to configure the server. -func NewServer(name, version string, opts *ServerOptions) *Server { +// The first argument must not be nil. +// +// If non-nil, the provided options are used to configure the server. +func NewServer(impl *Implementation, opts *ServerOptions) *Server { + if impl == nil { + panic("nil Implementation") + } if opts == nil { opts = new(ServerOptions) } @@ -85,8 +89,7 @@ func NewServer(name, version string, opts *ServerOptions) *Server { opts.PageSize = DefaultPageSize } return &Server{ - name: name, - version: version, + impl: impl, opts: *opts, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), @@ -686,10 +689,7 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam ProtocolVersion: version, Capabilities: ss.server.capabilities(), Instructions: ss.server.opts.Instructions, - ServerInfo: &implementation{ - Name: ss.server.name, - Version: ss.server.version, - }, + ServerInfo: ss.server.impl, }, nil } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index fd6eea00..3ab7a2a4 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -28,7 +28,7 @@ func ExampleServer() { ctx := context.Background() clientTransport, serverTransport := mcp.NewInMemoryTransports() - server := mcp.NewServer("greeter", "v0.0.1", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) serverSession, err := server.Connect(ctx, serverTransport) @@ -36,7 +36,7 @@ func ExampleServer() { log.Fatal(err) } - client := mcp.NewClient("client", "v0.0.1", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "client"}, nil) clientSession, err := client.Connect(ctx, clientTransport) if err != nil { log.Fatal(err) @@ -59,8 +59,8 @@ func ExampleServer() { // createSessions creates and connects an in-memory client and server session for testing purposes. func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { - server := mcp.NewServer("server", "v0.0.1", nil) - client := mcp.NewClient("client", "v0.0.1", nil) + server := mcp.NewServer(testImpl, nil) + client := mcp.NewClient(testImpl, nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() serverSession, err := server.Connect(ctx, serverTransport) if err != nil { diff --git a/mcp/server_test.go b/mcp/server_test.go index cc94003c..d4243d7c 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -306,7 +306,7 @@ func TestServerCapabilities(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - server := NewServer("", "", nil) + server := NewServer(testImpl, nil) tc.configureServer(server) gotCapabilities := server.capabilities() if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 816e0134..d8ce939b 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -27,7 +27,7 @@ func Add(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsF } func ExampleSSEHandler() { - server := mcp.NewServer("adder", "v0.0.1", nil) + server := mcp.NewServer(&mcp.Implementation{Name: "adder", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) @@ -36,7 +36,7 @@ func ExampleSSEHandler() { ctx := context.Background() transport := mcp.NewSSEClientTransport(httpServer.URL, nil) - client := mcp.NewClient("test", "v1.0.0", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "v1.0.0"}, nil) cs, err := client.Connect(ctx, transport) if err != nil { log.Fatal(err) diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 153185d3..846d68c0 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -20,7 +20,7 @@ func TestSSEServer(t *testing.T) { for _, closeServerFirst := range []bool{false, true} { t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) { ctx := context.Background() - server := NewServer("testServer", "v1.0.0", nil) + server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet"}, sayHi) sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) @@ -47,7 +47,7 @@ func TestSSEServer(t *testing.T) { HTTPClient: customClient, }) - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) cs, err := c.Connect(ctx, clientTransport) if err != nil { t.Fatal(err) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index da9c4285..e0e85a1e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -32,7 +32,7 @@ func TestStreamableTransports(t *testing.T) { ctx := context.Background() // 1. Create a server with a simple "greet" tool. - server := NewServer("testServer", "v1.0.0", nil) + server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. @@ -65,7 +65,7 @@ func TestStreamableTransports(t *testing.T) { transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ HTTPClient: httpClient, }) - client := NewClient("testClient", "v1.0.0", nil) + client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport) if err != nil { t.Fatalf("client.Connect() failed: %v", err) @@ -162,7 +162,7 @@ func TestStreamableServerTransport(t *testing.T) { Tools: &toolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion, - ServerInfo: &implementation{Name: "testServer", Version: "v1.0.0"}, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) initializedMsg := req(0, "initialized", &InitializedParams{}) initialize := step{ @@ -328,7 +328,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. - server := NewServer("testServer", "v1.0.0", nil) + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { if test.tool != nil { test.tool(t, ctx, ss) From 78a66a438a4ef06fa50c8e5d6c559669f26ccd3b Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 11 Jul 2025 08:08:14 -0400 Subject: [PATCH 018/221] README for v0.2.0 (#118) Preliminary readme for v0.2.0. We will make a GH "release" with the release notes. --- README.md | 23 +++++++++-------------- design/design.md | 28 ++++------------------------ internal/readme/README.src.md | 23 +++++++++-------------- mcp/server.go | 2 +- 4 files changed, 23 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index be91e751..ef74035e 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,11 @@ -# MCP Go SDK +# MCP Go SDK v0.2.0 ***BREAKING CHANGES*** -The latest version contains breaking changes: - -- Server.AddTools is replaced by Server.AddTool. - -- NewServerTool is replaced by AddTool. AddTool takes a Tool rather than a name and description, so you can - set any field on the Tool that you want before associating it with a handler. - -- Tool options have been removed. If you don't want AddTool to infer a JSON Schema for you, you can construct one - as a struct literal, or using any other code that suits you. - -- AddPrompts, AddResources and AddResourceTemplates are similarly replaced by singular methods which pair the - feature with a handler. The ServerXXX types have been removed. +This version contains breaking changes. +See the [release notes]( +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.2.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) @@ -41,7 +32,7 @@ open-ended discussion. See CONTRIBUTING.md for details. ## Package documentation -The SDK consists of two importable packages: +The SDK consists of three importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) @@ -51,6 +42,10 @@ The SDK consists of two importable packages: [`github.com/modelcontextprotocol/go-sdk/jsonschema`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) package provides an implementation of [JSON Schema](https://json-schema.org/), used for MCP tool input and output schema. +- The + [`github.com/modelcontextprotocol/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) package is for users implementing + their own transports. + ## Example diff --git a/design/design.md b/design/design.md index 56bef53e..8ab7c152 100644 --- a/design/design.md +++ b/design/design.md @@ -692,13 +692,9 @@ func (cs *ClientSession) CallTool(context.Context, *CallToolParams[json.RawMessa func CallTool[TArgs any](context.Context, *ClientSession, *CallToolParams[TArgs]) (*CallToolResult, error) ``` -**Differences from mcp-go**: using variadic options to configure tools was significantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options. For example, that limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), which must resort to untyped maps to express a nested schema. +**Differences from mcp-go**: We provide a full JSON Schema implementation for validating tool input schemas against incoming arguments. The `jsonschema.Schema` type provides exported features for all keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct any schema they want. The `jsonschema.For[T]` function can infer a schema from a Go struct. These combined features eliminate the need for variadic arguments to construct tool schemas. -Additionally, the `NewServerTool` helper provides a means for building a tool from a Go function using reflection, that automatically handles parsing and validation of inputs. - -We provide a full JSON Schema implementation for validating tool input schemas against incoming arguments. The `jsonschema.Schema` type provides exported features for all keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct any schema they want, so there is no need to provide options for all of them. When combined with schema inference from input structs, we found that we needed only three options to cover the common cases, instead of mcp-go's 23. For example, we will provide `Enum`, which occurs 125 times in open source code, but not MinItems, MinLength or MinProperties, which each occur only once (and in an SDK that wraps mcp-go). - -For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, `AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. (Similarly for Delete/Remove). +For registering tools, we provide only a `Server.AddTool` method; mcp-go's `SetTools`, `AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. (Similarly for Delete/Remove). The `AddTool` generic function combines schema inference with registration, providing a easy way to register many tools. ### Prompts @@ -722,7 +718,7 @@ server.RemovePrompts("code_review") Client sessions can call the spec method `ListPrompts` or the iterator method `Prompts` to list the available prompts, and the spec method `GetPrompt` to get one. -**Differences from mcp-go**: We provide a `NewServerPrompt` helper to bind a prompt handler to a Go function using reflection to derive its arguments. We provide `RemovePrompts` to remove prompts from the server. +**Differences from mcp-go**: We provide `RemovePrompts` to remove prompts from the server. ### Resources and resource templates @@ -746,25 +742,9 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) The `ReadResource` method finds a resource or resource template matching the argument URI and calls its associated handler. -To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: - -```go -// FileResourceHandler returns a ResourceHandler that reads paths using dir as a root directory. -// It protects against path traversal attacks. -// It will not read any file that is not in the root set of the client requesting the resource. -func (*Server) FileResourceHandler(dir string) ResourceHandler -``` - -Here is an example: - -```go -// Safely read "/public/puppies.txt". -s.AddResource(&mcp.Resource{URI: "file:///puppies.txt"}, s.FileReadResourceHandler("/public")) -``` - Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, and the corresponding iterator methods `Resources` and `ResourceTemplates`. -**Differences from mcp-go**: for symmetry with tools and prompts, we use `AddResources` rather than `AddResource`. Additionally, the `ResourceHandler` returns a `ReadResourceResult`, rather than just its content, for compatibility with future evolution of the spec. +**Differences from mcp-go**: The `ResourceHandler` returns a `ReadResourceResult`, rather than just its content, for compatibility with future evolution of the spec. #### Subscriptions diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 11d63110..c03d74d6 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,19 +1,10 @@ -# MCP Go SDK +# MCP Go SDK v0.2.0 ***BREAKING CHANGES*** -The latest version contains breaking changes: - -- Server.AddTools is replaced by Server.AddTool. - -- NewServerTool is replaced by AddTool. AddTool takes a Tool rather than a name and description, so you can - set any field on the Tool that you want before associating it with a handler. - -- Tool options have been removed. If you don't want AddTool to infer a JSON Schema for you, you can construct one - as a struct literal, or using any other code that suits you. - -- AddPrompts, AddResources and AddResourceTemplates are similarly replaced by singular methods which pair the - feature with a handler. The ServerXXX types have been removed. +This version contains breaking changes. +See the [release notes]( +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.2.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) @@ -40,7 +31,7 @@ open-ended discussion. See CONTRIBUTING.md for details. ## Package documentation -The SDK consists of two importable packages: +The SDK consists of three importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) @@ -50,6 +41,10 @@ The SDK consists of two importable packages: [`github.com/modelcontextprotocol/go-sdk/jsonschema`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) package provides an implementation of [JSON Schema](https://json-schema.org/), used for MCP tool input and output schema. +- The + [`github.com/modelcontextprotocol/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) package is for users implementing + their own transports. + ## Example diff --git a/mcp/server.go b/mcp/server.go index 060ae538..6b287ad7 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -388,7 +388,7 @@ func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, boo // The dir argument should be a filesystem path. It need not be absolute, but // that is recommended to avoid a dependency on the current working directory (the // check against client roots is done with an absolute path). If dir is not absolute -// and the current working directory is unavailable, FileResourceHandler panics. +// and the current working directory is unavailable, fileResourceHandler panics. // // Lexical path traversal attacks, where the path has ".." elements that escape dir, // are always caught. Go 1.24 and above also protects against symlink-based attacks, From 0f596bde17fb562a49590fc41125ce88916a8c71 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 11 Jul 2025 11:27:33 -0400 Subject: [PATCH 019/221] jsonschema: improve doc (#122) Improve the package documentation. Add a README and LICENSE, anticipating the relocation of this package to its own module elsewhere. --- jsonschema/LICENSE | 21 +++++++++++++++ jsonschema/README.md | 39 ++++++++++++++++++++++++++++ jsonschema/doc.go | 61 ++++++++++++++++++++++++++++++++------------ jsonschema/infer.go | 3 ++- 4 files changed, 106 insertions(+), 18 deletions(-) create mode 100644 jsonschema/LICENSE create mode 100644 jsonschema/README.md diff --git a/jsonschema/LICENSE b/jsonschema/LICENSE new file mode 100644 index 00000000..508be926 --- /dev/null +++ b/jsonschema/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Go MCP SDK Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/jsonschema/README.md b/jsonschema/README.md new file mode 100644 index 00000000..f316bedd --- /dev/null +++ b/jsonschema/README.md @@ -0,0 +1,39 @@ +TODO: this file should live at the root of the jsonschema-go module, +above the jsonschema package. + +# JSON Schema for GO + +This module implements the [JSON Schema](https://json-schema.org/) specification. +The `jsonschema` package supports creating schemas, validating JSON values +against a schema, and inferring a schema from a Go struct. See the package +documentation for usage. + +## Contributing + +This module welcomes external contributions. +It has no dependencies outside of the standard library, and can be built with +the standard Go toolchain. Run `go test ./...` at the module root to run all +the tests. + +## Issues + +This project uses the [GitHub issue +tracker](https://github.com/TODO/jsonschema-go/issues) for bug reports, feature requests, and other issues. + +Please [report +bugs](https://github.com/TODO/jsonschema-go/issues/new). If the SDK is +not working as you expected, it is likely due to a bug or inadequate +documentation, and reporting an issue will help us address this shortcoming. + +When reporting a bug, make sure to answer these five questions: + +1. What did you do? +2. What did you see? +3. What did you expect to see? +4. What version of the Go MCP SDK are you using? +5. What version of Go are you using (`go version`)? + +## License + +This project is licensed under the MIT license. See the LICENSE file for details. + diff --git a/jsonschema/doc.go b/jsonschema/doc.go index f25b000a..135bdf50 100644 --- a/jsonschema/doc.go +++ b/jsonschema/doc.go @@ -5,22 +5,30 @@ /* Package jsonschema is an implementation of the [JSON Schema specification], a JSON-based format for describing the structure of JSON data. -The package can be used to read schemas for code generation, and to validate data using the -draft 2020-12 specification. Validation with other drafts or custom meta-schemas -is not supported. +The package can be used to read schemas for code generation, and to validate +data using the draft 2020-12 specification. Validation with other drafts +or custom meta-schemas is not supported. -Construct a [Schema] as you would any Go struct (for example, by writing a struct -literal), or unmarshal a JSON schema into a [Schema] in the usual way (with [encoding/json], -for instance). It can then be used for code generation or other purposes without -further processing. +Construct a [Schema] as you would any Go struct (for example, by writing +a struct literal), or unmarshal a JSON schema into a [Schema] in the usual +way (with [encoding/json], for instance). It can then be used for code +generation or other purposes without further processing. +You can also infer a schema from a Go struct. + +# Resolution + +A Schema can refer to other schemas, both inside and outside itself. These +references must be resolved before a schema can be used for validation. +Call [Schema.Resolve] to obtain a resolved schema (called a [Resolved]). +If the schema has external references, pass a [ResolveOptions] with a [Loader] +to load them. To validate default values in a schema, set +[ResolveOptions.ValidateDefaults] to true. # Validation -Before using a Schema to validate a JSON value, you must first resolve it by calling -[Schema.Resolve]. -The call [Resolved.Validate] on the result to validate a JSON value. -The value must be a Go value that looks like the result of unmarshaling a JSON -value into an [any] or a struct. For example, the JSON value +Call [Resolved.Validate] to validate a JSON value. The value must be a +Go value that looks like the result of unmarshaling a JSON value into an +[any] or a struct. For example, the JSON value {"name": "Al", "scores": [90, 80, 100]} @@ -41,8 +49,11 @@ or as a value of this type: # Inference The [For] function returns a [Schema] describing the given Go type. -The type cannot contain any function or channel types, and any map types must have a string key. -For example, calling For on the above Player type results in this schema: +Each field in the struct becomes a property of the schema. +The values of "json" tags are respected: the field's property name is taken +from the tag, and fields omitted from the JSON are omitted from the schema as +well. +For example, `jsonschema.For[Player]()` returns this schema: { "properties": { @@ -58,16 +69,32 @@ For example, calling For on the above Player type results in this schema: } } +Use the "jsonschema" struct tag to provide a description for the property: + + type Player struct { + Name string `json:"name" jsonschema:"player name"` + Scores []int `json:"scores" jsonschema:"scores of player's games"` + } + # Deviations from the specification -Regular expressions are processed with Go's regexp package, which differs from ECMA 262, -most significantly in not supporting back-references. +Regular expressions are processed with Go's regexp package, which differs +from ECMA 262, most significantly in not supporting back-references. See [this table of differences] for more. -The value of the "format" keyword is recorded in the Schema, but is ignored during validation. +The "format" keyword described in [section 7 of the validation spec] is recorded +in the Schema, but is ignored during validation. It does not even produce [annotations]. +Use the "pattern" keyword instead: it will work more reliably across JSON Schema +implementations. See [learnjsonschema.com] for more recommendations about "format". + +The content keywords described in [section 8 of the validation spec] +are recorded in the schema, but ignored during validation. [JSON Schema specification]: https://json-schema.org +[section 7 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 +[section 8 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 +[learnjsonschema.org]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ [this table of differences] https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 [annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations */ diff --git a/jsonschema/infer.go b/jsonschema/infer.go index ae441291..654e6197 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -34,12 +34,13 @@ import ( // types, as they are incompatible with the JSON schema spec. // - maps with key other than 'string' // - function types +// - channel types // - complex numbers // - unsafe pointers // // It will return an error if there is a cycle in the types. // -// For recognizes struct field tags named "jsonschema". +// This function recognizes struct field tags named "jsonschema". // A jsonschema tag on a field is used as the description for the corresponding property. // For future compatibility, descriptions must not start with "WORD=", where WORD is a // sequence of non-whitespace characters. From aa9d4b28be56a4db23012db4fd46eac909db0abb Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Fri, 11 Jul 2025 17:34:18 -0400 Subject: [PATCH 020/221] jsonschema: add missing colon to link block (#126) Without this, not only is the one link not valid, the entire section is considered invalid and none of the link references are used. --- jsonschema/doc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jsonschema/doc.go b/jsonschema/doc.go index 135bdf50..0f0ba441 100644 --- a/jsonschema/doc.go +++ b/jsonschema/doc.go @@ -94,8 +94,8 @@ are recorded in the schema, but ignored during validation. [JSON Schema specification]: https://json-schema.org [section 7 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 [section 8 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 -[learnjsonschema.org]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ -[this table of differences] https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 +[learnjsonschema.com]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ +[this table of differences]: https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 [annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations */ package jsonschema From fe208c144fed5989a2ff952af7c56707a6c08d02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vid=20Drobni=C4=8D?= Date: Mon, 14 Jul 2025 13:01:40 +0200 Subject: [PATCH 021/221] jsonrpc: expose encoding and decoding functions (#114) Expose `jsonrpc2.MakeID`, `jsonrpc2.EncodeMessage` and `jsonrpc2.DecodeMessage` functions to allow implementing custom `mcp.Transport`. Add an example that demonstrates a custom transport implementation. Fixes #110. --- examples/custom-transport/main.go | 109 ++++++++++++++++++++++++++++++ internal/jsonrpc2/messages.go | 2 +- jsonrpc/jsonrpc.go | 19 ++++++ 3 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 examples/custom-transport/main.go diff --git a/examples/custom-transport/main.go b/examples/custom-transport/main.go new file mode 100644 index 00000000..cc4f15f3 --- /dev/null +++ b/examples/custom-transport/main.go @@ -0,0 +1,109 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bufio" + "context" + "errors" + "io" + "log" + "os" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// IOTransport is a simplified implementation of a transport that communicates using +// newline-delimited JSON over an io.Reader and io.Writer. It is similar to ioTransport +// in transport.go and serves as a demonstration of how to implement a custom transport. +type IOTransport struct { + r *bufio.Reader + w io.Writer +} + +// NewIOTransport creates a new IOTransport with the given io.Reader and io.Writer. +func NewIOTransport(r io.Reader, w io.Writer) *IOTransport { + return &IOTransport{ + r: bufio.NewReader(r), + w: w, + } +} + +// ioConn is a connection that uses newlines to delimit messages. It implements [mcp.Connection]. +type ioConn struct { + r *bufio.Reader + w io.Writer +} + +// Connect implements [mcp.Transport.Connect] by creating a new ioConn. +func (t *IOTransport) Connect(ctx context.Context) (mcp.Connection, error) { + return &ioConn{ + r: t.r, + w: t.w, + }, nil +} + +// Read implements [mcp.Connection.Read], assuming messages are newline-delimited JSON. +func (t *ioConn) Read(context.Context) (jsonrpc.Message, error) { + data, err := t.r.ReadBytes('\n') + if err != nil { + return nil, err + } + + return jsonrpc.DecodeMessage(data[:len(data)-1]) +} + +// Write implements [mcp.Connection.Write], appending a newline delimiter after the message. +func (t *ioConn) Write(_ context.Context, msg jsonrpc.Message) error { + data, err := jsonrpc.EncodeMessage(msg) + if err != nil { + return err + } + + _, err1 := t.w.Write(data) + _, err2 := t.w.Write([]byte{'\n'}) + return errors.Join(err1, err2) +} + +// Close implements [mcp.Connection.Close]. Since this is a simplified example, it is a no-op. +func (t *ioConn) Close() error { + return nil +} + +// SessionID implements [mcp.Connection.SessionID]. Since this is a simplified example, +// it returns an empty session ID. +func (t *ioConn) SessionID() string { + return "" +} + +// HiArgs is the argument type for the SayHi tool. +type HiArgs struct { + Name string `json:"name" mcp:"the name to say hi to"` +} + +// SayHi is a tool handler that responds with a greeting. +func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { + return &mcp.CallToolResultFor[struct{}]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + }, + }, nil +} + +func main() { + server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + + // Run the server with a custom IOTransport using stdio as the io.Reader and io.Writer. + transport := &IOTransport{ + r: bufio.NewReader(os.Stdin), + w: os.Stdout, + } + err := server.Run(context.Background(), transport) + if err != nil { + log.Println("[ERROR]: Failed to run server:", err) + } +} diff --git a/internal/jsonrpc2/messages.go b/internal/jsonrpc2/messages.go index 03371b91..2de3d4f0 100644 --- a/internal/jsonrpc2/messages.go +++ b/internal/jsonrpc2/messages.go @@ -19,7 +19,7 @@ type ID struct { // MakeID coerces the given Go value to an ID. The value is assumed to be the // default JSON marshaling of a Request identifier -- nil, float64, or string. // -// Returns an error if the value type was a valid Request ID type. +// Returns an error if the value type was not a valid Request ID type. // // TODO: ID can't be a json.Marshaler/Unmarshaler, because we want to omitzero. // Simplify this package by making ID json serializable once we can rely on diff --git a/jsonrpc/jsonrpc.go b/jsonrpc/jsonrpc.go index f175e597..1cf1202f 100644 --- a/jsonrpc/jsonrpc.go +++ b/jsonrpc/jsonrpc.go @@ -18,3 +18,22 @@ type ( // Response is a JSON-RPC response. Response = jsonrpc2.Response ) + +// MakeID coerces the given Go value to an ID. The value is assumed to be the +// default JSON marshaling of a Request identifier -- nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +func MakeID(v any) (ID, error) { + return jsonrpc2.MakeID(v) +} + +// EncodeMessage serializes a JSON-RPC message to its wire format. +func EncodeMessage(msg Message) ([]byte, error) { + return jsonrpc2.EncodeMessage(msg) +} + +// DecodeMessage deserializes JSON-RPC wire format data into a Message. +// It returns either a Request or Response based on the message content. +func DecodeMessage(data []byte) (Message, error) { + return jsonrpc2.DecodeMessage(data) +} From d570022d03fb4000a40d24357df68ad08fd67093 Mon Sep 17 00:00:00 2001 From: Naoki Sega <499108+nsega@users.noreply.github.com> Date: Mon, 14 Jul 2025 05:12:38 -0700 Subject: [PATCH 022/221] mcp: add sequential thinking server example (#51) Add example of implementing mcp on [sequentialthinking server](https://github.com/modelcontextprotocol/servers/tree/main/src/sequentialthinking), from [Example Servers](https://modelcontextprotocol.io/examples). The PR includes dynamic and reflective problem-solving through structured thinking processes. The server provides tools for: - starting thinking sessions - adding sequential thoughts with step tracking - revising previous thoughts - creating alternative reasoning branches - reviewing complete processes. Also, features include thread-safe session management, adaptive planning that adjusts step counts dynamically. --- examples/sequentialthinking/README.md | 203 +++++++++ examples/sequentialthinking/main.go | 544 ++++++++++++++++++++++ examples/sequentialthinking/main_test.go | 548 +++++++++++++++++++++++ 3 files changed, 1295 insertions(+) create mode 100644 examples/sequentialthinking/README.md create mode 100644 examples/sequentialthinking/main.go create mode 100644 examples/sequentialthinking/main_test.go diff --git a/examples/sequentialthinking/README.md b/examples/sequentialthinking/README.md new file mode 100644 index 00000000..40987b15 --- /dev/null +++ b/examples/sequentialthinking/README.md @@ -0,0 +1,203 @@ +# Sequential Thinking MCP Server + +This example shows a Model Context Protocol (MCP) server that enables dynamic and reflective problem-solving through structured thinking processes. It helps break down complex problems into manageable, sequential thought steps with support for revision and branching. + +## Features + +The server provides three main tools for managing thinking sessions: + +### 1. Start Thinking (`start_thinking`) + +Begins a new sequential thinking session for a complex problem. + +**Parameters:** + +- `problem` (string): The problem or question to think about +- `sessionId` (string, optional): Custom session identifier +- `estimatedSteps` (int, optional): Initial estimate of thinking steps needed + +### 2. Continue Thinking (`continue_thinking`) + +Adds the next thought step, revises previous steps, or creates alternative branches. + +**Parameters:** + +- `sessionId` (string): The thinking session to continue +- `thought` (string): The current thought or analysis +- `nextNeeded` (bool, optional): Whether another thinking step is needed +- `reviseStep` (int, optional): Step number to revise (1-based) +- `createBranch` (bool, optional): Create an alternative reasoning path +- `estimatedTotal` (int, optional): Update total estimated steps + +### 3. Review Thinking (`review_thinking`) + +Provides a complete review of the thinking process for a session. + +**Parameters:** + +- `sessionId` (string): The session to review + +## Resources + +### Thinking History (`thinking://sessions` or `thinking://{sessionId}`) + +Access thinking session data and history in JSON format. + +- `thinking://sessions` - List all thinking sessions +- `thinking://{sessionId}` - Get specific session details + +## Core Concepts + +### Sequential Processing + +Problems are broken down into numbered thought steps that build upon each other, maintaining context and allowing for systematic analysis. + +### Dynamic Revision + +Any previous thought step can be revised and updated, with the system tracking which thoughts have been modified. + +### Alternative Branching + +Create alternative reasoning paths to explore different approaches to the same problem, allowing for comparative analysis. + +### Adaptive Planning + +The estimated number of thinking steps can be adjusted dynamically as understanding of the problem evolves. + +## Running the Server + +### Standard I/O Mode + +```bash +go run . +``` + +### HTTP Mode + +```bash +go run . -http :8080 +``` + +## Example Usage + +### Starting a Thinking Session + +```json +{ + "method": "tools/call", + "params": { + "name": "start_thinking", + "arguments": { + "problem": "How should I design a scalable microservices architecture?", + "sessionId": "architecture_design", + "estimatedSteps": 8 + } + } +} +``` + +### Adding Sequential Thoughts + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "First, I need to identify the core business domains and their boundaries to determine service decomposition." + } + } +} +``` + +### Revising a Previous Step + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "Actually, before identifying domains, I should analyze the current system's pain points and requirements.", + "reviseStep": 1 + } + } +} +``` + +### Creating an Alternative Branch + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "Alternative approach: Start with a monolith-first strategy and extract services gradually.", + "createBranch": true + } + } +} +``` + +### Completing the Thinking Process + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "Based on this analysis, I recommend starting with 3 core services: User Management, Order Processing, and Inventory Management.", + "nextNeeded": false + } + } +} +``` + +### Reviewing the Complete Process + +```json +{ + "method": "tools/call", + "params": { + "name": "review_thinking", + "arguments": { + "sessionId": "architecture_design" + } + } +} +``` + +## Session State Management + +Each thinking session maintains: + +- **Session metadata**: ID, problem statement, creation time, current status +- **Thought sequence**: Ordered list of thoughts with timestamps and revision history +- **Progress tracking**: Current step and estimated total steps +- **Branch relationships**: Links to alternative reasoning paths +- **Status management**: Active, completed, or paused sessions + +## Use Cases + +**Ideal for:** + +- Complex problem analysis requiring step-by-step breakdown +- Design decisions needing systematic evaluation +- Scenarios where initial scope is unclear and may evolve +- Problems requiring alternative approach exploration +- Situations needing detailed reasoning documentation + +**Examples:** + +- Software architecture design +- Research methodology planning +- Strategic business decisions +- Technical troubleshooting +- Creative problem solving +- Academic research planning diff --git a/examples/sequentialthinking/main.go b/examples/sequentialthinking/main.go new file mode 100644 index 00000000..4830bb0a --- /dev/null +++ b/examples/sequentialthinking/main.go @@ -0,0 +1,544 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "crypto/rand" + "encoding/json" + "flag" + "fmt" + "log" + "maps" + "net/http" + "net/url" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + +// A Thought is a single step in the thinking process. +type Thought struct { + // Index of the thought within the session (1-based). + Index int `json:"index"` + // Content of the thought. + Content string `json:"content"` + // Time the thought was created. + Created time.Time `json:"created"` + // Whether the thought has been revised. + Revised bool `json:"revised"` + // Index of parent thought, or nil if this is a root for branching. + ParentIndex *int `json:"parentIndex,omitempty"` +} + +// A ThinkingSession is an active thinking session. +type ThinkingSession struct { + // Globally unique ID of the session. + ID string `json:"id"` + // Problem to solve. + Problem string `json:"problem"` + // Thoughts in the session. + Thoughts []*Thought `json:"thoughts"` + // Current thought index. + CurrentThought int `json:"currentThought"` + // Estimated total number of thoughts. + EstimatedTotal int `json:"estimatedTotal"` + // Status of the session. + Status string `json:"status"` // "active", "completed", "paused" + // Time the session was created. + Created time.Time `json:"created"` + // Time the session was last active. + LastActivity time.Time `json:"lastActivity"` + // Branches in the session. Alternative thought paths. + Branches []string `json:"branches,omitempty"` + // Version for optimistic concurrency control. + Version int `json:"version"` +} + +// clone returns a deep copy of the ThinkingSession. +func (s *ThinkingSession) clone() *ThinkingSession { + sessionCopy := *s + sessionCopy.Thoughts = deepCopyThoughts(s.Thoughts) + sessionCopy.Branches = slices.Clone(s.Branches) + return &sessionCopy +} + +// A SessionStore is a global session store (in a real implementation, this might be a database). +// +// Locking Strategy: +// The SessionStore uses a RWMutex to protect the sessions map from concurrent access. +// All ThinkingSession modifications happen on deep copies, never on shared instances. +// This means: +// - Read locks protect map access. +// - Write locks protect map modifications (adding/removing/replacing sessions) +// - Session field modifications always happen on local copies via CompareAndSwap +// - No shared ThinkingSession state is ever modified directly +type SessionStore struct { + mu sync.RWMutex + sessions map[string]*ThinkingSession // key is session ID +} + +// NewSessionStore creates a new session store for managing thinking sessions. +func NewSessionStore() *SessionStore { + return &SessionStore{ + sessions: make(map[string]*ThinkingSession), + } +} + +// Session retrieves a thinking session by ID, returning the session and whether it exists. +func (s *SessionStore) Session(id string) (*ThinkingSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + session, exists := s.sessions[id] + return session, exists +} + +// SetSession stores or updates a thinking session in the store. +func (s *SessionStore) SetSession(session *ThinkingSession) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions[session.ID] = session +} + +// CompareAndSwap atomically updates a session if the version matches. +// Returns true if the update succeeded, false if there was a version mismatch. +// +// This method implements optimistic concurrency control: +// 1. Read lock to safely access the map and copy the session +// 2. Deep copy the session (all modifications happen on this copy) +// 3. Release read lock and apply updates to the copy +// 4. Write lock to check version and atomically update if unchanged +// +// The read lock in step 1 is necessary to prevent map access races, +// not to protect ThinkingSession fields (which are never modified in-place). +func (s *SessionStore) CompareAndSwap(sessionID string, updateFunc func(*ThinkingSession) (*ThinkingSession, error)) error { + for { + // Get current session + s.mu.RLock() + current, exists := s.sessions[sessionID] + if !exists { + s.mu.RUnlock() + return fmt.Errorf("session %s not found", sessionID) + } + // Create a deep copy + sessionCopy := current.clone() + oldVersion := current.Version + s.mu.RUnlock() + + // Apply the update + updated, err := updateFunc(sessionCopy) + if err != nil { + return err + } + + // Try to save + s.mu.Lock() + current, exists = s.sessions[sessionID] + if !exists { + s.mu.Unlock() + return fmt.Errorf("session %s not found", sessionID) + } + if current.Version != oldVersion { + // Version mismatch, retry + s.mu.Unlock() + continue + } + updated.Version = oldVersion + 1 + s.sessions[sessionID] = updated + s.mu.Unlock() + return nil + } +} + +// Sessions returns all thinking sessions in the store. +func (s *SessionStore) Sessions() []*ThinkingSession { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Collect(maps.Values(s.sessions)) +} + +// SessionsSnapshot returns a deep copy of all sessions for safe concurrent access. +func (s *SessionStore) SessionsSnapshot() []*ThinkingSession { + s.mu.RLock() + defer s.mu.RUnlock() + + var sessions []*ThinkingSession + for _, session := range s.sessions { + sessions = append(sessions, session.clone()) + } + return sessions +} + +// SessionSnapshot returns a deep copy of a session for safe concurrent access. +// The second return value reports whether a session with the given id exists. +func (s *SessionStore) SessionSnapshot(id string) (*ThinkingSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + session, exists := s.sessions[id] + if !exists { + return nil, false + } + + return session.clone(), true +} + +var store = NewSessionStore() + +// StartThinkingArgs are the arguments for starting a new thinking session. +type StartThinkingArgs struct { + Problem string `json:"problem"` + SessionID string `json:"sessionId,omitempty"` + EstimatedSteps int `json:"estimatedSteps,omitempty"` +} + +// ContinueThinkingArgs are the arguments for continuing a thinking session. +type ContinueThinkingArgs struct { + SessionID string `json:"sessionId"` + Thought string `json:"thought"` + NextNeeded *bool `json:"nextNeeded,omitempty"` + ReviseStep *int `json:"reviseStep,omitempty"` + CreateBranch bool `json:"createBranch,omitempty"` + EstimatedTotal int `json:"estimatedTotal,omitempty"` +} + +// ReviewThinkingArgs are the arguments for reviewing a thinking session. +type ReviewThinkingArgs struct { + SessionID string `json:"sessionId"` +} + +// ThinkingHistoryArgs are the arguments for retrieving thinking history. +type ThinkingHistoryArgs struct { + SessionID string `json:"sessionId"` +} + +// deepCopyThoughts creates a deep copy of a slice of thoughts. +func deepCopyThoughts(thoughts []*Thought) []*Thought { + thoughtsCopy := make([]*Thought, len(thoughts)) + for i, t := range thoughts { + t2 := *t + thoughtsCopy[i] = &t2 + } + return thoughtsCopy +} + +// StartThinking begins a new sequential thinking session for a complex problem. +func StartThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[StartThinkingArgs]) (*mcp.CallToolResultFor[any], error) { + args := params.Arguments + + sessionID := args.SessionID + if sessionID == "" { + sessionID = randText() + } + + estimatedSteps := args.EstimatedSteps + if estimatedSteps == 0 { + estimatedSteps = 5 // Default estimate + } + + session := &ThinkingSession{ + ID: sessionID, + Problem: args.Problem, + EstimatedTotal: estimatedSteps, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + + store.SetSession(session) + + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Started thinking session '%s' for problem: %s\nEstimated steps: %d\nReady for your first thought.", + sessionID, args.Problem, estimatedSteps), + }, + }, + }, nil +} + +// ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. +func ContinueThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[ContinueThinkingArgs]) (*mcp.CallToolResultFor[any], error) { + args := params.Arguments + + // Handle revision of existing thought + if args.ReviseStep != nil { + err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { + stepIndex := *args.ReviseStep - 1 + if stepIndex < 0 || stepIndex >= len(session.Thoughts) { + return nil, fmt.Errorf("invalid step number: %d", *args.ReviseStep) + } + + session.Thoughts[stepIndex].Content = args.Thought + session.Thoughts[stepIndex].Revised = true + session.LastActivity = time.Now() + return session, nil + }) + if err != nil { + return nil, err + } + + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Revised step %d in session '%s':\n%s", + *args.ReviseStep, args.SessionID, args.Thought), + }, + }, + }, nil + } + + // Handle branching + if args.CreateBranch { + var branchID string + var branchSession *ThinkingSession + + err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { + branchID = fmt.Sprintf("%s_branch_%d", args.SessionID, len(session.Branches)+1) + session.Branches = append(session.Branches, branchID) + session.LastActivity = time.Now() + + // Create a new session for the branch (deep copy thoughts) + thoughtsCopy := deepCopyThoughts(session.Thoughts) + branchSession = &ThinkingSession{ + ID: branchID, + Problem: session.Problem + " (Alternative branch)", + Thoughts: thoughtsCopy, + CurrentThought: len(session.Thoughts), + EstimatedTotal: session.EstimatedTotal, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + + return session, nil + }) + if err != nil { + return nil, err + } + + // Save the branch session + store.SetSession(branchSession) + + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Created branch '%s' from session '%s'. You can now continue thinking in either session.", + branchID, args.SessionID), + }, + }, + }, nil + } + + // Add new thought + var thoughtID int + var progress string + var statusMsg string + + err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { + thoughtID = len(session.Thoughts) + 1 + thought := &Thought{ + Index: thoughtID, + Content: args.Thought, + Created: time.Now(), + Revised: false, + } + + session.Thoughts = append(session.Thoughts, thought) + session.CurrentThought = thoughtID + session.LastActivity = time.Now() + + // Update estimated total if provided + if args.EstimatedTotal > 0 { + session.EstimatedTotal = args.EstimatedTotal + } + + // Check if thinking is complete + if args.NextNeeded != nil && !*args.NextNeeded { + session.Status = "completed" + } + + // Prepare response strings + progress = fmt.Sprintf("Step %d", thoughtID) + if session.EstimatedTotal > 0 { + progress += fmt.Sprintf(" of ~%d", session.EstimatedTotal) + } + + if session.Status == "completed" { + statusMsg = "\n✓ Thinking process completed!" + } else { + statusMsg = "\nReady for next thought..." + } + + return session, nil + }) + if err != nil { + return nil, err + } + + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Session '%s' - %s:\n%s%s", + args.SessionID, progress, args.Thought, statusMsg), + }, + }, + }, nil +} + +// ReviewThinking provides a complete review of the thinking process for a session. +func ReviewThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[ReviewThinkingArgs]) (*mcp.CallToolResultFor[any], error) { + args := params.Arguments + + // Get a snapshot of the session to avoid race conditions + sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) + if !exists { + return nil, fmt.Errorf("session %s not found", args.SessionID) + } + + var review strings.Builder + fmt.Fprintf(&review, "=== Thinking Review: %s ===\n", sessionSnapshot.ID) + fmt.Fprintf(&review, "Problem: %s\n", sessionSnapshot.Problem) + fmt.Fprintf(&review, "Status: %s\n", sessionSnapshot.Status) + fmt.Fprintf(&review, "Steps: %d of ~%d\n", len(sessionSnapshot.Thoughts), sessionSnapshot.EstimatedTotal) + + if len(sessionSnapshot.Branches) > 0 { + fmt.Fprintf(&review, "Branches: %s\n", strings.Join(sessionSnapshot.Branches, ", ")) + } + + fmt.Fprintf(&review, "\n--- Thought Sequence ---\n") + + for i, thought := range sessionSnapshot.Thoughts { + status := "" + if thought.Revised { + status = " (revised)" + } + fmt.Fprintf(&review, "%d. %s%s\n", i+1, thought.Content, status) + } + + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: review.String(), + }, + }, + }, nil +} + +// ThinkingHistory handles resource requests for thinking session data and history. +func ThinkingHistory(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { + // Extract session ID from URI (e.g., "thinking://session_123") + u, err := url.Parse(params.URI) + if err != nil { + return nil, fmt.Errorf("invalid thinking resource URI: %s", params.URI) + } + if u.Scheme != "thinking" { + return nil, fmt.Errorf("invalid thinking resource URI scheme: %s", u.Scheme) + } + + sessionID := u.Host + if sessionID == "sessions" { + // List all sessions - use snapshot for thread safety + sessions := store.SessionsSnapshot() + data, err := json.MarshalIndent(sessions, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal sessions: %w", err) + } + + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: params.URI, + MIMEType: "application/json", + Text: string(data), + }, + }, + }, nil + } + + // Get specific session - use snapshot for thread safety + session, exists := store.SessionSnapshot(sessionID) + if !exists { + return nil, fmt.Errorf("session %s not found", sessionID) + } + + data, err := json.MarshalIndent(session, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal session: %w", err) + } + + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: params.URI, + MIMEType: "application/json", + Text: string(data), + }, + }, + }, nil +} + +// Copied from crypto/rand. +// TODO: once 1.24 is assured, just use crypto/rand. +const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" + +func randText() string { + // ⌈log₃₂ 2¹²⁸⌉ = 26 chars + src := make([]byte, 26) + rand.Read(src) + for i := range src { + src[i] = base32alphabet[src[i]%32] + } + return string(src) +} + +func main() { + flag.Parse() + + server := mcp.NewServer(&mcp.Implementation{Name: "sequential-thinking"}, nil) + + mcp.AddTool(server, &mcp.Tool{ + Name: "start_thinking", + Description: "Begin a new sequential thinking session for a complex problem", + }, StartThinking) + + mcp.AddTool(server, &mcp.Tool{ + Name: "continue_thinking", + Description: "Add the next thought step, revise a previous step, or create a branch", + }, ContinueThinking) + + mcp.AddTool(server, &mcp.Tool{ + Name: "review_thinking", + Description: "Review the complete thinking process for a session", + }, ReviewThinking) + + server.AddResource(&mcp.Resource{ + Name: "thinking_sessions", + Description: "Access thinking session data and history", + URI: "thinking://sessions", + MIMEType: "application/json", + }, ThinkingHistory) + + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("Sequential Thinking MCP server listening at %s", *httpAddr) + if err := http.ListenAndServe(*httpAddr, handler); err != nil { + log.Fatal(err) + } + } else { + t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} diff --git a/examples/sequentialthinking/main_test.go b/examples/sequentialthinking/main_test.go new file mode 100644 index 00000000..13cd9ccb --- /dev/null +++ b/examples/sequentialthinking/main_test.go @@ -0,0 +1,548 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestStartThinking(t *testing.T) { + // Reset store for clean test + store = NewSessionStore() + + ctx := context.Background() + + args := StartThinkingArgs{ + Problem: "How to implement a binary search algorithm", + SessionID: "test_session", + EstimatedSteps: 5, + } + + params := &mcp.CallToolParamsFor[StartThinkingArgs]{ + Name: "start_thinking", + Arguments: args, + } + + result, err := StartThinking(ctx, nil, params) + if err != nil { + t.Fatalf("StartThinking() error = %v", err) + } + + if len(result.Content) == 0 { + t.Fatal("No content in result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "test_session") { + t.Error("Result should contain session ID") + } + + if !strings.Contains(textContent.Text, "How to implement a binary search algorithm") { + t.Error("Result should contain the problem statement") + } + + // Verify session was stored + session, exists := store.Session("test_session") + if !exists { + t.Fatal("Session was not stored") + } + + if session.Problem != args.Problem { + t.Errorf("Expected problem %s, got %s", args.Problem, session.Problem) + } + + if session.EstimatedTotal != 5 { + t.Errorf("Expected estimated total 5, got %d", session.EstimatedTotal) + } + + if session.Status != "active" { + t.Errorf("Expected status 'active', got %s", session.Status) + } +} + +func TestContinueThinking(t *testing.T) { + // Reset store and create initial session + store = NewSessionStore() + + // First start a thinking session + ctx := context.Background() + startArgs := StartThinkingArgs{ + Problem: "Test problem", + SessionID: "test_continue", + EstimatedSteps: 3, + } + + startParams := &mcp.CallToolParamsFor[StartThinkingArgs]{ + Name: "start_thinking", + Arguments: startArgs, + } + + _, err := StartThinking(ctx, nil, startParams) + if err != nil { + t.Fatalf("StartThinking() error = %v", err) + } + + // Now continue thinking + continueArgs := ContinueThinkingArgs{ + SessionID: "test_continue", + Thought: "First thought: I need to understand the problem", + } + + continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ + Name: "continue_thinking", + Arguments: continueArgs, + } + + result, err := ContinueThinking(ctx, nil, continueParams) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Verify result + if len(result.Content) == 0 { + t.Fatal("No content in result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Step 1") { + t.Error("Result should contain step number") + } + + // Verify session was updated + session, exists := store.Session("test_continue") + if !exists { + t.Fatal("Session not found") + } + + if len(session.Thoughts) != 1 { + t.Errorf("Expected 1 thought, got %d", len(session.Thoughts)) + } + + if session.Thoughts[0].Content != continueArgs.Thought { + t.Errorf("Expected thought content %s, got %s", continueArgs.Thought, session.Thoughts[0].Content) + } + + if session.CurrentThought != 1 { + t.Errorf("Expected current thought 1, got %d", session.CurrentThought) + } +} + +func TestContinueThinkingWithCompletion(t *testing.T) { + // Reset store and create initial session + store = NewSessionStore() + + ctx := context.Background() + startArgs := StartThinkingArgs{ + Problem: "Simple test", + SessionID: "test_completion", + } + + startParams := &mcp.CallToolParamsFor[StartThinkingArgs]{ + Name: "start_thinking", + Arguments: startArgs, + } + + _, err := StartThinking(ctx, nil, startParams) + if err != nil { + t.Fatalf("StartThinking() error = %v", err) + } + + // Continue with completion flag + nextNeeded := false + continueArgs := ContinueThinkingArgs{ + SessionID: "test_completion", + Thought: "Final thought", + NextNeeded: &nextNeeded, + } + + continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ + Name: "continue_thinking", + Arguments: continueArgs, + } + + result, err := ContinueThinking(ctx, nil, continueParams) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Check completion message + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "completed") { + t.Error("Result should indicate completion") + } + + // Verify session status + session, exists := store.Session("test_completion") + if !exists { + t.Fatal("Session not found") + } + + if session.Status != "completed" { + t.Errorf("Expected status 'completed', got %s", session.Status) + } +} + +func TestContinueThinkingRevision(t *testing.T) { + // Setup session with existing thoughts + store = NewSessionStore() + session := &ThinkingSession{ + ID: "test_revision", + Problem: "Test problem", + Thoughts: []*Thought{ + {Index: 1, Content: "Original thought", Created: time.Now()}, + {Index: 2, Content: "Second thought", Created: time.Now()}, + }, + CurrentThought: 2, + EstimatedTotal: 3, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + store.SetSession(session) + + ctx := context.Background() + reviseStep := 1 + continueArgs := ContinueThinkingArgs{ + SessionID: "test_revision", + Thought: "Revised first thought", + ReviseStep: &reviseStep, + } + + continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ + Name: "continue_thinking", + Arguments: continueArgs, + } + + result, err := ContinueThinking(ctx, nil, continueParams) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Verify revision message + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Revised step 1") { + t.Error("Result should indicate revision") + } + + // Verify thought was revised + updatedSession, _ := store.Session("test_revision") + if updatedSession.Thoughts[0].Content != "Revised first thought" { + t.Error("First thought should be revised") + } + + if !updatedSession.Thoughts[0].Revised { + t.Error("First thought should be marked as revised") + } +} + +func TestContinueThinkingBranching(t *testing.T) { + // Setup session with existing thoughts + store = NewSessionStore() + session := &ThinkingSession{ + ID: "test_branch", + Problem: "Test problem", + Thoughts: []*Thought{ + {Index: 1, Content: "First thought", Created: time.Now()}, + }, + CurrentThought: 1, + EstimatedTotal: 3, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + Branches: []string{}, + } + store.SetSession(session) + + ctx := context.Background() + continueArgs := ContinueThinkingArgs{ + SessionID: "test_branch", + Thought: "Alternative approach", + CreateBranch: true, + } + + continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ + Name: "continue_thinking", + Arguments: continueArgs, + } + + result, err := ContinueThinking(ctx, nil, continueParams) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Verify branch creation message + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Created branch") { + t.Error("Result should indicate branch creation") + } + + // Verify branch was created + updatedSession, _ := store.Session("test_branch") + if len(updatedSession.Branches) != 1 { + t.Errorf("Expected 1 branch, got %d", len(updatedSession.Branches)) + } + + branchID := updatedSession.Branches[0] + if !strings.Contains(branchID, "test_branch_branch_") { + t.Error("Branch ID should contain parent session ID") + } + + // Verify branch session exists + branchSession, exists := store.Session(branchID) + if !exists { + t.Fatal("Branch session should exist") + } + + if len(branchSession.Thoughts) != 1 { + t.Error("Branch should inherit parent thoughts") + } +} + +func TestReviewThinking(t *testing.T) { + // Setup session with thoughts + store = NewSessionStore() + session := &ThinkingSession{ + ID: "test_review", + Problem: "Complex problem", + Thoughts: []*Thought{ + {Index: 1, Content: "First thought", Created: time.Now(), Revised: false}, + {Index: 2, Content: "Second thought", Created: time.Now(), Revised: true}, + {Index: 3, Content: "Final thought", Created: time.Now(), Revised: false}, + }, + CurrentThought: 3, + EstimatedTotal: 3, + Status: "completed", + Created: time.Now(), + LastActivity: time.Now(), + Branches: []string{"test_review_branch_1"}, + } + store.SetSession(session) + + ctx := context.Background() + reviewArgs := ReviewThinkingArgs{ + SessionID: "test_review", + } + + reviewParams := &mcp.CallToolParamsFor[ReviewThinkingArgs]{ + Name: "review_thinking", + Arguments: reviewArgs, + } + + result, err := ReviewThinking(ctx, nil, reviewParams) + if err != nil { + t.Fatalf("ReviewThinking() error = %v", err) + } + + // Verify review content + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + reviewText := textContent.Text + + if !strings.Contains(reviewText, "test_review") { + t.Error("Review should contain session ID") + } + + if !strings.Contains(reviewText, "Complex problem") { + t.Error("Review should contain problem") + } + + if !strings.Contains(reviewText, "completed") { + t.Error("Review should contain status") + } + + if !strings.Contains(reviewText, "Steps: 3 of ~3") { + t.Error("Review should contain step count") + } + + if !strings.Contains(reviewText, "First thought") { + t.Error("Review should contain first thought") + } + + if !strings.Contains(reviewText, "(revised)") { + t.Error("Review should indicate revised thoughts") + } + + if !strings.Contains(reviewText, "test_review_branch_1") { + t.Error("Review should list branches") + } +} + +func TestThinkingHistory(t *testing.T) { + // Setup test sessions + store = NewSessionStore() + session1 := &ThinkingSession{ + ID: "session1", + Problem: "Problem 1", + Thoughts: []*Thought{{Index: 1, Content: "Thought 1", Created: time.Now()}}, + CurrentThought: 1, + EstimatedTotal: 2, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + session2 := &ThinkingSession{ + ID: "session2", + Problem: "Problem 2", + Thoughts: []*Thought{{Index: 1, Content: "Thought 1", Created: time.Now()}}, + CurrentThought: 1, + EstimatedTotal: 3, + Status: "completed", + Created: time.Now(), + LastActivity: time.Now(), + } + store.SetSession(session1) + store.SetSession(session2) + + ctx := context.Background() + + // Test listing all sessions + listParams := &mcp.ReadResourceParams{ + URI: "thinking://sessions", + } + + result, err := ThinkingHistory(ctx, nil, listParams) + if err != nil { + t.Fatalf("ThinkingHistory() error = %v", err) + } + + if len(result.Contents) != 1 { + t.Fatal("Expected 1 content item") + } + + content := result.Contents[0] + if content.MIMEType != "application/json" { + t.Error("Expected JSON MIME type") + } + + // Parse and verify sessions list + var sessions []*ThinkingSession + err = json.Unmarshal([]byte(content.Text), &sessions) + if err != nil { + t.Fatalf("Failed to parse sessions JSON: %v", err) + } + + if len(sessions) != 2 { + t.Errorf("Expected 2 sessions, got %d", len(sessions)) + } + + // Test getting specific session + sessionParams := &mcp.ReadResourceParams{ + URI: "thinking://session1", + } + + result, err = ThinkingHistory(ctx, nil, sessionParams) + if err != nil { + t.Fatalf("ThinkingHistory() error = %v", err) + } + + var retrievedSession ThinkingSession + err = json.Unmarshal([]byte(result.Contents[0].Text), &retrievedSession) + if err != nil { + t.Fatalf("Failed to parse session JSON: %v", err) + } + + if retrievedSession.ID != "session1" { + t.Errorf("Expected session ID 'session1', got %s", retrievedSession.ID) + } + + if retrievedSession.Problem != "Problem 1" { + t.Errorf("Expected problem 'Problem 1', got %s", retrievedSession.Problem) + } +} + +func TestInvalidOperations(t *testing.T) { + store = NewSessionStore() + ctx := context.Background() + + // Test continue thinking with non-existent session + continueArgs := ContinueThinkingArgs{ + SessionID: "nonexistent", + Thought: "Some thought", + } + + continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ + Name: "continue_thinking", + Arguments: continueArgs, + } + + _, err := ContinueThinking(ctx, nil, continueParams) + if err == nil { + t.Error("Expected error for non-existent session") + } + + // Test review with non-existent session + reviewArgs := ReviewThinkingArgs{ + SessionID: "nonexistent", + } + + reviewParams := &mcp.CallToolParamsFor[ReviewThinkingArgs]{ + Name: "review_thinking", + Arguments: reviewArgs, + } + + _, err = ReviewThinking(ctx, nil, reviewParams) + if err == nil { + t.Error("Expected error for non-existent session in review") + } + + // Test invalid revision step + session := &ThinkingSession{ + ID: "test_invalid", + Problem: "Test", + Thoughts: []*Thought{{Index: 1, Content: "Thought", Created: time.Now()}}, + CurrentThought: 1, + EstimatedTotal: 2, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + store.SetSession(session) + + reviseStep := 5 // Invalid step number + invalidReviseArgs := ContinueThinkingArgs{ + SessionID: "test_invalid", + Thought: "Revised", + ReviseStep: &reviseStep, + } + + invalidReviseParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ + Name: "continue_thinking", + Arguments: invalidReviseArgs, + } + + _, err = ContinueThinking(ctx, nil, invalidReviseParams) + if err == nil { + t.Error("Expected error for invalid revision step") + } +} From b780f1b3e1f8dc9ee85e61366ded9c03fe1f4e87 Mon Sep 17 00:00:00 2001 From: Ryuji Iwata Date: Mon, 14 Jul 2025 21:21:30 +0900 Subject: [PATCH 023/221] Fix jsonrpc path and url (#128) jsonrpc's github path notation and pkg.go.dev link url are broken. --- README.md | 2 +- internal/readme/README.src.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ef74035e..45b16f9c 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ The SDK consists of three importable packages: package provides an implementation of [JSON Schema](https://json-schema.org/), used for MCP tool input and output schema. - The - [`github.com/modelcontextprotocol/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) package is for users implementing + [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing their own transports. diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index c03d74d6..5f0f1520 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -42,7 +42,7 @@ The SDK consists of three importable packages: package provides an implementation of [JSON Schema](https://json-schema.org/), used for MCP tool input and output schema. - The - [`github.com/modelcontextprotocol/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) package is for users implementing + [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing their own transports. From a26e1dc25351c547a2d89e57ee075b5db17722a4 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Mon, 14 Jul 2025 10:02:33 -0400 Subject: [PATCH 024/221] mcp: fail when streamable http client encounters unsupported content (#130) The streamable HTTP client is not fully implemented, but did not fail with meaningful error messages when encountering unsupported content types from other servers. This may be straightforward to fix, but for now at least provide a meaningful error message. For modelcontextprotocol/go-sdk#129. --- mcp/streamable.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 11d70a38..d371c873 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -645,8 +645,7 @@ type streamableClientConn struct { mu sync.Mutex protocolVersion string _sessionID string - // bodies map[*http.Response]io.Closer - err error + err error } func (c *streamableClientConn) setProtocolVersion(s string) { @@ -741,10 +740,16 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string } sessionID = resp.Header.Get(sessionIDHeader) - if resp.Header.Get("Content-Type") == "text/event-stream" { + switch ct := resp.Header.Get("Content-Type"); ct { + case "text/event-stream": go s.handleSSE(resp) - } else { + case "application/json": + // TODO: read the body and send to s.incoming (in a select that also receives from s.done). + resp.Body.Close() + return "", fmt.Errorf("streamable HTTP client does not yet support raw JSON responses") + default: resp.Body.Close() + return "", fmt.Errorf("unsupported content type %q", ct) } return sessionID, nil } @@ -760,7 +765,11 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) { // TODO: surface this error; possibly break the stream return } - s.incoming <- evt.data + select { + case <-s.done: + return + case s.incoming <- evt.data: + } } }() From bf7e518215fff7cf505c4a40f5d1a2843d0c63be Mon Sep 17 00:00:00 2001 From: Qihang Hu Date: Mon, 14 Jul 2025 23:39:30 +0800 Subject: [PATCH 025/221] docs: fix the service example in the readme and examples/hello (#127) Fix an example using mcp tag instead of jsonschema. --- README.md | 2 +- examples/hello/main.go | 2 +- internal/readme/server/server.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 45b16f9c..037885e5 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ import ( ) type HiParams struct { - Name string `json:"name", mcp:"the name of the person to greet"` + Name string `json:"name" jsonschema:"the name of the person to greet"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { diff --git a/examples/hello/main.go b/examples/hello/main.go index 84cff1e0..f01a6c99 100644 --- a/examples/hello/main.go +++ b/examples/hello/main.go @@ -19,7 +19,7 @@ import ( var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") type HiArgs struct { - Name string `json:"name" mcp:"the name to say hi to"` + Name string `json:"name" jsonschema:"the name to say hi to"` } func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 7c025412..8901e773 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -13,7 +13,7 @@ import ( ) type HiParams struct { - Name string `json:"name", mcp:"the name of the person to greet"` + Name string `json:"name" jsonschema:"the name of the person to greet"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { From de4b788e59958929f49cd445b71c014e1c2ef8bc Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 14 Jul 2025 12:47:58 -0400 Subject: [PATCH 026/221] mcp: fix conformance tests (#132) - Use new AddXXX API. - Update goldens to reflect new capability logic. Fixes #131. --- mcp/conformance_test.go | 75 +++++++++++++------ mcp/mcp_test.go | 28 ++++--- mcp/testdata/conformance/server/prompts.txtar | 6 -- .../conformance/server/resources.txtar | 6 -- mcp/testdata/conformance/server/tools.txtar | 6 -- .../conformance/server/version-latest.txtar | 13 +--- .../conformance/server/version-older.txtar | 11 +-- 7 files changed, 70 insertions(+), 75 deletions(-) diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 3be54a3e..883d8a89 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" "golang.org/x/tools/txtar" ) @@ -46,12 +47,12 @@ var update = flag.Bool("update", false, "if set, update conformance test data") // with -update to have the test runner update the expected output, which may // be client or server depending on the perspective of the test. type conformanceTest struct { - name string // test name - path string // path to test file - archive *txtar.Archive // raw archive, for updating - tools, prompts, resources []string // named features to include - client []JSONRPCMessage // client messages - server []JSONRPCMessage // server messages + name string // test name + path string // path to test file + archive *txtar.Archive // raw archive, for updating + tools, prompts, resources []string // named features to include + client []jsonrpc.Message // client messages + server []jsonrpc.Message // server messages } // TODO(rfindley): add client conformance tests. @@ -100,10 +101,36 @@ func TestServerConformance(t *testing.T) { func runServerTest(t *testing.T, test *conformanceTest) { ctx := t.Context() // Construct the server based on features listed in the test. - s := NewServer("testServer", "v1.0.0", nil) - add(tools, s.AddTools, test.tools...) - add(prompts, s.AddPrompts, test.prompts...) - add(resources, s.AddResources, test.resources...) + s := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + for _, tn := range test.tools { + switch tn { + case "greet": + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + default: + t.Fatalf("unknown tool %q", tn) + } + } + for _, pn := range test.prompts { + switch pn { + case "code_review": + s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) + default: + t.Fatalf("unknown prompt %q", pn) + } + } + for _, rn := range test.resources { + switch rn { + case "info.txt": + s.AddResource(resource1, readHandler) + case "info": + s.AddResource(resource3, handleEmbeddedResource) + default: + t.Fatalf("unknown resource %q", rn) + } + } // Connect the server, and connect the client stream, // but don't connect an actual client. @@ -117,24 +144,24 @@ func runServerTest(t *testing.T, test *conformanceTest) { t.Fatal(err) } - writeMsg := func(msg JSONRPCMessage) { + writeMsg := func(msg jsonrpc.Message) { if err := cStream.Write(ctx, msg); err != nil { t.Fatalf("Write failed: %v", err) } } var ( - serverMessages []JSONRPCMessage - outRequests []*JSONRPCRequest - outResponses []*JSONRPCResponse + serverMessages []jsonrpc.Message + outRequests []*jsonrpc.Request + outResponses []*jsonrpc.Response ) // Separate client requests and responses; we use them differently. for _, msg := range test.client { switch msg := msg.(type) { - case *JSONRPCRequest: + case *jsonrpc.Request: outRequests = append(outRequests, msg) - case *JSONRPCResponse: + case *jsonrpc.Response: outResponses = append(outResponses, msg) default: t.Fatalf("bad message type %T", msg) @@ -143,7 +170,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // nextResponse handles incoming requests and notifications, and returns the // next incoming response. - nextResponse := func() (*JSONRPCResponse, error, bool) { + nextResponse := func() (*jsonrpc.Response, error, bool) { for { msg, err := cStream.Read(ctx) if err != nil { @@ -156,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { return nil, err, false } serverMessages = append(serverMessages, msg) - if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { // Pair up the next outgoing response with this request. // We assume requests arrive in the same order every time. if len(outResponses) == 0 { @@ -167,7 +194,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { outResponses = outResponses[1:] continue } - return msg.(*JSONRPCResponse), nil, true + return msg.(*jsonrpc.Response), nil, true } } @@ -191,7 +218,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // There might be more notifications or requests, but there shouldn't be more // responses. // Run this in a goroutine so the current thread can wait for it. - var extra *JSONRPCResponse + var extra *jsonrpc.Response go func() { extra, err, _ = nextResponse() }() @@ -240,8 +267,8 @@ func runServerTest(t *testing.T, test *conformanceTest) { t.Fatalf("os.WriteFile(%q) failed: %v", test.path, err) } } else { - // JSONRPCMessages are not comparable, so we instead compare lines of JSON. - transform := cmpopts.AcyclicTransformer("toJSON", func(msg JSONRPCMessage) []string { + // jsonrpc.Messages are not comparable, so we instead compare lines of JSON. + transform := cmpopts.AcyclicTransformer("toJSON", func(msg jsonrpc.Message) []string { encoded, err := jsonrpc2.EncodeIndent(msg, "", "\t") if err != nil { t.Fatal(err) @@ -271,9 +298,9 @@ func loadConformanceTest(dir, path string) (*conformanceTest, error) { } // decodeMessages loads JSON-RPC messages from the archive file. - decodeMessages := func(data []byte) ([]JSONRPCMessage, error) { + decodeMessages := func(data []byte) ([]jsonrpc.Message, error) { dec := json.NewDecoder(bytes.NewReader(data)) - var res []JSONRPCMessage + var res []jsonrpc.Message for dec.More() { var raw json.RawMessage if err := dec.Decode(&raw); err != nil { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 1952e8d0..7da2b857 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -39,6 +39,21 @@ func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiP return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + params.Arguments.Name}}}, nil } +var codeReviewPrompt = &Prompt{ + Name: "code_review", + Description: "do a code review", + Arguments: []*PromptArgument{{Name: "Code", Required: true}}, +} + +func codReviewPromptHandler(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{ + {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, + }, + }, nil +} + func TestEndToEnd(t *testing.T) { ctx := context.Background() var ct, st Transport = NewInMemoryTransports() @@ -73,18 +88,7 @@ func TestEndToEnd(t *testing.T) { func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { return nil, errTestFailure }) - s.AddPrompt(&Prompt{ - Name: "code_review", - Description: "do a code review", - Arguments: []*PromptArgument{{Name: "Code", Required: true}}, - }, func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { - return &GetPromptResult{ - Description: "Code review prompt", - Messages: []*PromptMessage{ - {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, - }, - }, nil - }) + s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { return nil, errTestFailure }) diff --git a/mcp/testdata/conformance/server/prompts.txtar b/mcp/testdata/conformance/server/prompts.txtar index 6168ce8e..078e0915 100644 --- a/mcp/testdata/conformance/server/prompts.txtar +++ b/mcp/testdata/conformance/server/prompts.txtar @@ -29,12 +29,6 @@ code_review "logging": {}, "prompts": { "listChanged": true - }, - "resources": { - "listChanged": true - }, - "tools": { - "listChanged": true } }, "protocolVersion": "2024-11-05", diff --git a/mcp/testdata/conformance/server/resources.txtar b/mcp/testdata/conformance/server/resources.txtar index 3e7031ad..5bb5515d 100644 --- a/mcp/testdata/conformance/server/resources.txtar +++ b/mcp/testdata/conformance/server/resources.txtar @@ -47,14 +47,8 @@ info.txt "capabilities": { "completions": {}, "logging": {}, - "prompts": { - "listChanged": true - }, "resources": { "listChanged": true - }, - "tools": { - "listChanged": true } }, "protocolVersion": "2024-11-05", diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index a43cd075..07ad942a 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -30,12 +30,6 @@ greet "capabilities": { "completions": {}, "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "listChanged": true - }, "tools": { "listChanged": true } diff --git a/mcp/testdata/conformance/server/version-latest.txtar b/mcp/testdata/conformance/server/version-latest.txtar index 760bf8b7..89454fb3 100644 --- a/mcp/testdata/conformance/server/version-latest.txtar +++ b/mcp/testdata/conformance/server/version-latest.txtar @@ -19,18 +19,9 @@ response with its latest supported version. "result": { "capabilities": { "completions": {}, - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "listChanged": true - }, - "tools": { - "listChanged": true - } + "logging": {} }, - "protocolVersion": "2025-03-26", + "protocolVersion": "2025-06-18", "serverInfo": { "name": "testServer", "version": "v1.0.0" diff --git a/mcp/testdata/conformance/server/version-older.txtar b/mcp/testdata/conformance/server/version-older.txtar index 97f7b79b..55240954 100644 --- a/mcp/testdata/conformance/server/version-older.txtar +++ b/mcp/testdata/conformance/server/version-older.txtar @@ -19,16 +19,7 @@ support. "result": { "capabilities": { "completions": {}, - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "listChanged": true - }, - "tools": { - "listChanged": true - } + "logging": {} }, "protocolVersion": "2024-11-05", "serverInfo": { From dd49b7039b69498b242fe0ffea5016b734e6181c Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Thu, 17 Jul 2025 07:44:15 -0700 Subject: [PATCH 027/221] mcp: Add example of using Middleware for logging purposes (#58) I needed to implement some logging for our experimental server that we are building out using the SDK, and figured I could propose an example that might make it easier for folks in the future. For #33. --- mcp/example_middleware_test.go | 141 +++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 mcp/example_middleware_test.go diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go new file mode 100644 index 00000000..1328473a --- /dev/null +++ b/mcp/example_middleware_test.go @@ -0,0 +1,141 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log/slog" + "os" + "time" + + "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// This example demonstrates server side logging using the mcp.Middleware system. +func Example_loggingMiddleware() { + // Create a logger for demonstration purposes. + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + // Simplify timestamp format for consistent output. + if a.Key == slog.TimeKey { + return slog.String("time", "2025-01-01T00:00:00Z") + } + return a + }, + })) + + loggingMiddleware := func(next mcp.MethodHandler[*mcp.ServerSession]) mcp.MethodHandler[*mcp.ServerSession] { + return func( + ctx context.Context, + session *mcp.ServerSession, + method string, + params mcp.Params, + ) (mcp.Result, error) { + logger.Info("MCP method started", + "method", method, + "session_id", session.ID(), + "has_params", params != nil, + ) + + start := time.Now() + + result, err := next(ctx, session, method, params) + + duration := time.Since(start) + + if err != nil { + logger.Error("MCP method failed", + "method", method, + "session_id", session.ID(), + "duration_ms", duration.Milliseconds(), + "err", err, + ) + } else { + logger.Info("MCP method completed", + "method", method, + "session_id", session.ID(), + "duration_ms", duration.Milliseconds(), + "has_result", result != nil, + ) + } + + return result, err + } + } + + // Create server with middleware + server := mcp.NewServer(&mcp.Implementation{Name: "logging-example"}, nil) + server.AddReceivingMiddleware(loggingMiddleware) + + // Add a simple tool + server.AddTool( + &mcp.Tool{ + Name: "greet", + Description: "Greet someone with logging.", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": { + Type: "string", + Description: "Name to greet", + }, + }, + Required: []string{"name"}, + }, + }, + func( + ctx context.Context, + ss *mcp.ServerSession, + params *mcp.CallToolParamsFor[map[string]any], + ) (*mcp.CallToolResultFor[any], error) { + name, ok := params.Arguments["name"].(string) + if !ok { + return nil, fmt.Errorf("name parameter is required and must be a string") + } + + message := fmt.Sprintf("Hello, %s!", name) + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: message}, + }, + }, nil + }, + ) + + // Create client-server connection for demonstration + client := mcp.NewClient(&mcp.Implementation{Name: "test-client"}, nil) + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + ctx := context.Background() + + // Connect server and client + serverSession, _ := server.Connect(ctx, serverTransport) + defer serverSession.Close() + + clientSession, _ := client.Connect(ctx, clientTransport) + defer clientSession.Close() + + // Call the tool to demonstrate logging + result, _ := clientSession.CallTool(ctx, &mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]any{ + "name": "World", + }, + }) + + fmt.Printf("Tool result: %s\n", result.Content[0].(*mcp.TextContent).Text) + + // Output: + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=initialize session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=initialize session_id="" duration_ms=0 has_result=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=notifications/initialized session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=notifications/initialized session_id="" duration_ms=0 has_result=false + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=tools/call session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=tools/call session_id="" duration_ms=0 has_result=true + // Tool result: Hello, World! +} From d437c81cbeeaecde495897ce843905cb1713585f Mon Sep 17 00:00:00 2001 From: CSK <73425927+cr2007@users.noreply.github.com> Date: Fri, 18 Jul 2025 16:43:47 +0400 Subject: [PATCH 028/221] docs: Enhance README formatting and fix relative links (#146) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR improves the formatting and internal link structure of both the main `README.md` and the source file at `internal/readme/README.src.md`. Changes include: ### 🛠 Changes Made * Replaced plain-text warning with a GitHub-style admonition block (`> [!WARNING]`) for improved visibility and consistency with other docs. Example: > [!WARNING] > The SDK should be considered unreleased, and is currently unstable > and subject to breaking changes. Please test it out and file bug reports or API > proposals, but don't use it in real projects. See the issue tracker for known > issues and missing features. We aim to release a stable version of the SDK in > August, 2025. * Updated markdown links to use cleaner relative paths: * Changed `CONTRIBUTING.md` reference to `[CONTRIBUTING.md](/CONTRIBUTING.md)` * Updated `examples/` reference to `[examples/](/examples/)` * Ensured that all updates were mirrored in both the generated and source README files. ### 📌 Why It Matters These changes improve the readability and usability of the documentation, especially when viewed on GitHub. The updated formatting helps call attention to important warnings, and fixing relative paths ensures that links work reliably across platforms and contexts. --- README.md | 15 ++++++++------- internal/readme/README.src.md | 15 ++++++++------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 037885e5..d899afdd 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,12 @@ https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.2.0) for details. This repository contains an unreleased implementation of the official Go software development kit (SDK) for the Model Context Protocol (MCP). -**WARNING**: The SDK should be considered unreleased, and is currently unstable -and subject to breaking changes. Please test it out and file bug reports or API -proposals, but don't use it in real projects. See the issue tracker for known -issues and missing features. We aim to release a stable version of the SDK in -August, 2025. +> [!WARNING] +> The SDK should be considered unreleased, and is currently unstable +> and subject to breaking changes. Please test it out and file bug reports or API +> proposals, but don't use it in real projects. See the issue tracker for known +> issues and missing features. We aim to release a stable version of the SDK in +> August, 2025. ## Design @@ -28,7 +29,7 @@ Further design discussion should occur in [issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete proposals) or [discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See CONTRIBUTING.md for details. +open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. ## Package documentation @@ -130,7 +131,7 @@ func main() { } ``` -The `examples/` directory contains more example clients and servers. +The [`examples/`](/examples/) directory contains more example clients and servers. ## Acknowledgements diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 5f0f1520..0e239f81 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -11,11 +11,12 @@ https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.2.0) for details. This repository contains an unreleased implementation of the official Go software development kit (SDK) for the Model Context Protocol (MCP). -**WARNING**: The SDK should be considered unreleased, and is currently unstable -and subject to breaking changes. Please test it out and file bug reports or API -proposals, but don't use it in real projects. See the issue tracker for known -issues and missing features. We aim to release a stable version of the SDK in -August, 2025. +> [!WARNING] +> The SDK should be considered unreleased, and is currently unstable +> and subject to breaking changes. Please test it out and file bug reports or API +> proposals, but don't use it in real projects. See the issue tracker for known +> issues and missing features. We aim to release a stable version of the SDK in +> August, 2025. ## Design @@ -27,7 +28,7 @@ Further design discussion should occur in [issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete proposals) or [discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See CONTRIBUTING.md for details. +open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. ## Package documentation @@ -58,7 +59,7 @@ with its client over stdin/stdout: %include server/server.go - -The `examples/` directory contains more example clients and servers. +The [`examples/`](/examples/) directory contains more example clients and servers. ## Acknowledgements From dced3e4503f474bc36b48dc3b1038aec2df7ffd8 Mon Sep 17 00:00:00 2001 From: Martin Emde Date: Sat, 19 Jul 2025 14:08:56 -0700 Subject: [PATCH 029/221] mcp: don't omit required fields in ImageContent and AudioContent (#95) Update `ImageContent` and `AudioContent` to use an inline wire format to ensure that required fields are not omitted during JSON marshaling, matching TypeScript schema requirements. This follows the TextContent approach. - `ImageContent` requires `data` and `mimeType` fields. - `AudioContent` requires `data` and `mimeType` fields. --- mcp/content.go | 36 ++++++++++++++++++++++++++++++------ mcp/content_test.go | 16 ++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/mcp/content.go b/mcp/content.go index fd027cf8..04888dec 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -58,13 +58,25 @@ type ImageContent struct { } func (c *ImageContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ Type: "image", MIMEType: c.MIMEType, - Data: c.Data, + Data: data, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *ImageContent) fromWire(wire *wireContent) { @@ -83,13 +95,25 @@ type AudioContent struct { } func (c AudioContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ Type: "audio", MIMEType: c.MIMEType, - Data: c.Data, + Data: data, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *AudioContent) fromWire(wire *wireContent) { diff --git a/mcp/content_test.go b/mcp/content_test.go index 7a549bea..9366b0d4 100644 --- a/mcp/content_test.go +++ b/mcp/content_test.go @@ -45,6 +45,14 @@ func TestContent(t *testing.T) { }, `{"type":"image","mimeType":"image/png","data":"YTFiMmMz"}`, }, + { + &mcp.ImageContent{MIMEType: "image/png", Data: []byte{}}, + `{"type":"image","mimeType":"image/png","data":""}`, + }, + { + &mcp.ImageContent{Data: []byte("test")}, + `{"type":"image","mimeType":"","data":"dGVzdA=="}`, + }, { &mcp.ImageContent{ Data: []byte("a1b2c3"), @@ -61,6 +69,14 @@ func TestContent(t *testing.T) { }, `{"type":"audio","mimeType":"audio/wav","data":"YTFiMmMz"}`, }, + { + &mcp.AudioContent{MIMEType: "audio/wav", Data: []byte{}}, + `{"type":"audio","mimeType":"audio/wav","data":""}`, + }, + { + &mcp.AudioContent{Data: []byte("test")}, + `{"type":"audio","mimeType":"","data":"dGVzdA=="}`, + }, { &mcp.AudioContent{ Data: []byte("a1b2c3"), From 3bbe74f8a512c5b0a9b015a2ddc205a2dccc5ec3 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 21 Jul 2025 10:41:00 -0400 Subject: [PATCH 030/221] jsonschema: make schema state independent of resolve arguments (#85) Move schema state that depends on resolution into the Resolved struct. The remaining unexported schema fields are dependent only on the schema itself and its sub-schemas. Now a schema can be resolved multiple times, as itself or as part of other schemas. Fixes #84. --- jsonschema/resolve.go | 177 ++++++++++++++++++++++++++----------- jsonschema/resolve_test.go | 24 +++-- jsonschema/schema.go | 62 +------------ jsonschema/validate.go | 44 +++++---- 4 files changed, 171 insertions(+), 136 deletions(-) diff --git a/jsonschema/resolve.go b/jsonschema/resolve.go index 58a44b2b..cc551e79 100644 --- a/jsonschema/resolve.go +++ b/jsonschema/resolve.go @@ -25,12 +25,77 @@ type Resolved struct { root *Schema // map from $ids to their schemas resolvedURIs map[string]*Schema + // map from schemas to additional info computed during resolution + resolvedInfos map[*Schema]*resolvedInfo +} + +func newResolved(s *Schema) *Resolved { + return &Resolved{ + root: s, + resolvedURIs: map[string]*Schema{}, + resolvedInfos: map[*Schema]*resolvedInfo{}, + } +} + +// resolvedInfo holds information specific to a schema that is computed by [Schema.Resolve]. +type resolvedInfo struct { + s *Schema + // The JSON Pointer path from the root schema to here. + // Used in errors. + path string + // The schema's base schema. + // If the schema is the root or has an ID, its base is itself. + // Otherwise, its base is the innermost enclosing schema whose base + // is itself. + // Intuitively, a base schema is one that can be referred to with a + // fragmentless URI. + base *Schema + // The URI for the schema, if it is the root or has an ID. + // Otherwise nil. + // Invariants: + // s.base.uri != nil. + // s.base == s <=> s.uri != nil + uri *url.URL + // The schema to which Ref refers. + resolvedRef *Schema + + // If the schema has a dynamic ref, exactly one of the next two fields + // will be non-zero after successful resolution. + // The schema to which the dynamic ref refers when it acts lexically. + resolvedDynamicRef *Schema + // The anchor to look up on the stack when the dynamic ref acts dynamically. + dynamicRefAnchor string + + // The following fields are independent of arguments to Schema.Resolved, + // so they could live on the Schema. We put them here for simplicity. + + // The set of required properties. + isRequired map[string]bool + + // Compiled regexps. + pattern *regexp.Regexp + patternProperties map[*regexp.Regexp]*Schema + + // Map from anchors to subschemas. + anchors map[string]anchorInfo } // Schema returns the schema that was resolved. // It must not be modified. func (r *Resolved) Schema() *Schema { return r.root } +// schemaString returns a short string describing the schema. +func (r *Resolved) schemaString(s *Schema) string { + if s.ID != "" { + return s.ID + } + info := r.resolvedInfos[s] + if info.path != "" { + return info.path + } + return "" +} + // A Loader reads and unmarshals the schema at uri, if any. type Loader func(uri *url.URL) (*Schema, error) @@ -59,6 +124,8 @@ type ResolveOptions struct { // Resolve resolves all references within the schema and performs other tasks that // prepare the schema for validation. // If opts is nil, the default values are used. +// The schema must not be changed after Resolve is called. +// The same schema may be resolved multiple times. func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { // There are up to five steps required to prepare a schema to validate. // 1. Load: read the schema from somewhere and unmarshal it. @@ -71,9 +138,6 @@ func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { // in a map from URIs to schemas within root. // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. - if root.path != "" { - return nil, fmt.Errorf("jsonschema: Resolve: %s already resolved", root) - } r := &resolver{loaded: map[string]*Resolved{}} if opts != nil { r.opts = *opts @@ -121,20 +185,21 @@ func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { if baseURI.Fragment != "" { return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) } - if err := s.check(); err != nil { + rs := newResolved(s) + + if err := s.check(rs.resolvedInfos); err != nil { return nil, err } - m, err := resolveURIs(s, baseURI) - if err != nil { + if err := resolveURIs(rs, baseURI); err != nil { return nil, err } - rs := &Resolved{root: s, resolvedURIs: m} + // Remember the schema by both the URI we loaded it from and its canonical name, // which may differ if the schema has an $id. // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. r.loaded[baseURI.String()] = rs - r.loaded[s.uri.String()] = rs + r.loaded[rs.resolvedInfos[s].uri.String()] = rs if err := r.resolveRefs(rs); err != nil { return nil, err @@ -142,10 +207,10 @@ func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { return rs, nil } -func (root *Schema) check() error { +func (root *Schema) check(infos map[*Schema]*resolvedInfo) error { // Check for structural validity. Do this first and fail fast: // bad structure will cause other code to panic. - if err := root.checkStructure(); err != nil { + if err := root.checkStructure(infos); err != nil { return err } @@ -153,14 +218,16 @@ func (root *Schema) check() error { report := func(err error) { errs = append(errs, err) } for ss := range root.all() { - ss.checkLocal(report) + ss.checkLocal(report, infos) } return errors.Join(errs...) } // checkStructure verifies that root and its subschemas form a tree. // It also assigns each schema a unique path, to improve error messages. -func (root *Schema) checkStructure() error { +func (root *Schema) checkStructure(infos map[*Schema]*resolvedInfo) error { + assert(len(infos) == 0, "non-empty infos") + var check func(reflect.Value, []byte) error check = func(v reflect.Value, path []byte) error { // For the purpose of error messages, the root schema has path "root" @@ -173,16 +240,15 @@ func (root *Schema) checkStructure() error { if s == nil { return fmt.Errorf("jsonschema: schema at %s is nil", p) } - if s.path != "" { + if info, ok := infos[s]; ok { // We've seen s before. // The schema graph at root is not a tree, but it needs to - // be because we assume a unique parent when we store a schema's base - // in the Schema. A cycle would also put Schema.all into an infinite - // recursion. + // be because a schema's base must be unique. + // A cycle would also put Schema.all into an infinite recursion. return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", - root, s.path, p) + root, info.path, p) } - s.path = p + infos[s] = &resolvedInfo{s: s, path: p} for _, info := range schemaFieldInfos { fv := v.Elem().FieldByIndex(info.sf.Index) @@ -224,7 +290,7 @@ func (root *Schema) checkStructure() error { // Since checking a regexp involves compiling it, checkLocal saves those compiled regexps // in the schema for later use. // It appends the errors it finds to errs. -func (s *Schema) checkLocal(report func(error)) { +func (s *Schema) checkLocal(report func(error), infos map[*Schema]*resolvedInfo) { addf := func(format string, args ...any) { msg := fmt.Sprintf(format, args...) report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) @@ -250,33 +316,35 @@ func (s *Schema) checkLocal(report func(error)) { addf("cannot validate a schema with $vocabulary") } + info := infos[s] + // Check and compile regexps. if s.Pattern != "" { re, err := regexp.Compile(s.Pattern) if err != nil { addf("pattern: %v", err) } else { - s.pattern = re + info.pattern = re } } if len(s.PatternProperties) > 0 { - s.patternProperties = map[*regexp.Regexp]*Schema{} + info.patternProperties = map[*regexp.Regexp]*Schema{} for reString, subschema := range s.PatternProperties { re, err := regexp.Compile(reString) if err != nil { addf("patternProperties[%q]: %v", reString, err) continue } - s.patternProperties[re] = subschema + info.patternProperties[re] = subschema } } // Build a set of required properties, to avoid quadratic behavior when validating // a struct. if len(s.Required) > 0 { - s.isRequired = map[string]bool{} + info.isRequired = map[string]bool{} for _, r := range s.Required { - s.isRequired[r] = true + info.isRequired[r] = true } } } @@ -285,8 +353,6 @@ func (s *Schema) checkLocal(report func(error)) { // to baseURI. // See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section // 8.2.1. - -// TODO(jba): dynamicAnchors (§8.2.2) // // Every schema has a base URI and a parent base URI. // @@ -316,11 +382,12 @@ func (s *Schema) checkLocal(report func(error)) { // allOf/1 http://b.com (absolute $id; doesn't matter that it's not under the loaded URI) // allOf/2 http://a.com/root.json (inherited from parent) // allOf/2/not http://a.com/root.json (inherited from parent) -func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { - resolvedURIs := map[string]*Schema{} - +func resolveURIs(rs *Resolved, baseURI *url.URL) error { var resolve func(s, base *Schema) error resolve = func(s, base *Schema) error { + info := rs.resolvedInfos[s] + baseInfo := rs.resolvedInfos[base] + // ids are scoped to the root. if s.ID != "" { // A non-empty ID establishes a new base. @@ -332,26 +399,27 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { return fmt.Errorf("$id %s must not have a fragment", s.ID) } // The base URI for this schema is its $id resolved against the parent base. - s.uri = base.uri.ResolveReference(idURI) - if !s.uri.IsAbs() { - return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %s)", s.ID, s.base.uri) + info.uri = baseInfo.uri.ResolveReference(idURI) + if !info.uri.IsAbs() { + return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %q)", s.ID, baseInfo.uri) } - resolvedURIs[s.uri.String()] = s + rs.resolvedURIs[info.uri.String()] = s base = s // needed for anchors + baseInfo = rs.resolvedInfos[base] } - s.base = base + info.base = base // Anchors and dynamic anchors are URI fragments that are scoped to their base. // We treat them as keys in a map stored within the schema. setAnchor := func(anchor string, dynamic bool) error { if anchor != "" { - if _, ok := base.anchors[anchor]; ok { - return fmt.Errorf("duplicate anchor %q in %s", anchor, base.uri) + if _, ok := baseInfo.anchors[anchor]; ok { + return fmt.Errorf("duplicate anchor %q in %s", anchor, baseInfo.uri) } - if base.anchors == nil { - base.anchors = map[string]anchorInfo{} + if baseInfo.anchors == nil { + baseInfo.anchors = map[string]anchorInfo{} } - base.anchors[anchor] = anchorInfo{s, dynamic} + baseInfo.anchors[anchor] = anchorInfo{s, dynamic} } return nil } @@ -368,13 +436,11 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { } // Set the root URI to the base for now. If the root has an $id, this will change. - root.uri = baseURI + rs.resolvedInfos[rs.root].uri = baseURI // The original base, even if changed, is still a valid way to refer to the root. - resolvedURIs[baseURI.String()] = root - if err := resolve(root, root); err != nil { - return nil, err - } - return resolvedURIs, nil + rs.resolvedURIs[baseURI.String()] = rs.root + + return resolve(rs.root, rs.root) } // resolveRefs replaces every ref in the schemas with the schema it refers to. @@ -382,6 +448,7 @@ func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { // that needs to be loaded. func (r *resolver) resolveRefs(rs *Resolved) error { for s := range rs.root.all() { + info := rs.resolvedInfos[s] if s.Ref != "" { refSchema, _, err := r.resolveRef(rs, s, s.Ref) if err != nil { @@ -389,7 +456,7 @@ func (r *resolver) resolveRefs(rs *Resolved) error { } // Whether or not the anchor referred to by $ref fragment is dynamic, // the ref still treats it lexically. - s.resolvedRef = refSchema + info.resolvedRef = refSchema } if s.DynamicRef != "" { refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) @@ -399,11 +466,11 @@ func (r *resolver) resolveRefs(rs *Resolved) error { if frag != "" { // The dynamic ref's fragment points to a dynamic anchor. // We must resolve the fragment at validation time. - s.dynamicRefAnchor = frag + info.dynamicRefAnchor = frag } else { // There is no dynamic anchor in the lexically referenced schema, // so the dynamic ref behaves like a lexical ref. - s.resolvedDynamicRef = refSchema + info.resolvedDynamicRef = refSchema } } } @@ -417,7 +484,8 @@ func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, d return nil, "", err } // URI-resolve the ref against the current base URI to get a complete URI. - refURI = s.base.uri.ResolveReference(refURI) + base := rs.resolvedInfos[s].base + refURI = rs.resolvedInfos[base].uri.ResolveReference(refURI) // The non-fragment part of a ref URI refers to the base URI of some schema. // This part is the same for dynamic refs too: their non-fragment part resolves // lexically. @@ -447,6 +515,13 @@ func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, d } referencedSchema = lrs.root assert(referencedSchema != nil, "nil referenced schema") + // Copy the resolvedInfos from lrs into rs, without overwriting + // (hence we can't use maps.Insert). + for s, i := range lrs.resolvedInfos { + if rs.resolvedInfos[s] == nil { + rs.resolvedInfos[s] = i + } + } } } @@ -456,7 +531,9 @@ func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, d // A JSON Pointer is either the empty string or begins with a '/', // whereas anchors are always non-empty strings that don't contain slashes. if frag != "" && !strings.HasPrefix(frag, "/") { - info, found := referencedSchema.anchors[frag] + resInfo := rs.resolvedInfos[referencedSchema] + info, found := resInfo.anchors[frag] + if !found { return nil, "", fmt.Errorf("no anchor %q in %s", frag, s) } diff --git a/jsonschema/resolve_test.go b/jsonschema/resolve_test.go index 1b176bfa..36aa424b 100644 --- a/jsonschema/resolve_test.go +++ b/jsonschema/resolve_test.go @@ -17,7 +17,8 @@ import ( func TestSchemaStructure(t *testing.T) { check := func(s *Schema, want string) { t.Helper() - err := s.checkStructure() + infos := map[*Schema]*resolvedInfo{} + err := s.checkStructure(infos) if err == nil || !strings.Contains(err.Error(), want) { t.Errorf("checkStructure returned error %q, want %q", err, want) } @@ -89,13 +90,14 @@ func TestPaths(t *testing.T) { {root.PrefixItems[1], "/prefixItems/1"}, {root.PrefixItems[1].Items, "/prefixItems/1/items"}, } - if err := root.checkStructure(); err != nil { + rs := newResolved(root) + if err := root.checkStructure(rs.resolvedInfos); err != nil { t.Fatal(err) } var got []item for s := range root.all() { - got = append(got, item{s, s.path}) + got = append(got, item{s, rs.resolvedInfos[s].path}) } if !slices.Equal(got, want) { t.Errorf("\ngot %v\nwant %v", got, want) @@ -129,8 +131,12 @@ func TestResolveURIs(t *testing.T) { if err != nil { t.Fatal(err) } - got, err := resolveURIs(root, base) - if err != nil { + + rs := newResolved(root) + if err := root.check(rs.resolvedInfos); err != nil { + t.Fatal(err) + } + if err := resolveURIs(rs, base); err != nil { t.Fatal(err) } @@ -154,6 +160,7 @@ func TestResolveURIs(t *testing.T) { }, } + got := rs.resolvedURIs gotKeys := slices.Sorted(maps.Keys(got)) wantKeys := slices.Sorted(maps.Keys(wantIDs)) if !slices.Equal(gotKeys, wantKeys) { @@ -163,11 +170,12 @@ func TestResolveURIs(t *testing.T) { t.Errorf("IDs:\ngot %+v\n\nwant %+v", got, wantIDs) } for s := range root.all() { + info := rs.resolvedInfos[s] if want := wantAnchors[s]; want != nil { - if got := s.anchors; !maps.Equal(got, want) { + if got := info.anchors; !maps.Equal(got, want) { t.Errorf("anchors:\ngot %+v\n\nwant %+v", got, want) } - } else if s.anchors != nil { + } else if info.anchors != nil { t.Errorf("non-nil anchors for %s", s) } } @@ -199,7 +207,7 @@ func TestRefCycle(t *testing.T) { check := func(s *Schema, key string) { t.Helper() - if s.resolvedRef != schemas[key] { + if rs.resolvedInfos[s].resolvedRef != schemas[key] { t.Errorf("%s resolvedRef != schemas[%q]", s.json(), key) } } diff --git a/jsonschema/schema.go b/jsonschema/schema.go index 26623f1b..4b1d6eed 100644 --- a/jsonschema/schema.go +++ b/jsonschema/schema.go @@ -13,9 +13,7 @@ import ( "iter" "maps" "math" - "net/url" "reflect" - "regexp" "slices" "github.com/modelcontextprotocol/go-sdk/internal/util" @@ -129,47 +127,6 @@ type Schema struct { // Extra allows for additional keywords beyond those specified. Extra map[string]any `json:"-"` - - // computed fields - - // This schema's base schema. - // If the schema is the root or has an ID, its base is itself. - // Otherwise, its base is the innermost enclosing schema whose base - // is itself. - // Intuitively, a base schema is one that can be referred to with a - // fragmentless URI. - base *Schema - - // The URI for the schema, if it is the root or has an ID. - // Otherwise nil. - // Invariants: - // s.base.uri != nil. - // s.base == s <=> s.uri != nil - uri *url.URL - - // The JSON Pointer path from the root schema to here. - // Used in errors. - path string - - // The schema to which Ref refers. - resolvedRef *Schema - - // If the schema has a dynamic ref, exactly one of the next two fields - // will be non-zero after successful resolution. - // The schema to which the dynamic ref refers when it acts lexically. - resolvedDynamicRef *Schema - // The anchor to look up on the stack when the dynamic ref acts dynamically. - dynamicRefAnchor string - - // Map from anchors to subschemas. - anchors map[string]anchorInfo - - // compiled regexps - pattern *regexp.Regexp - patternProperties map[*regexp.Regexp]*Schema - - // the set of required properties - isRequired map[string]bool } // falseSchema returns a new Schema tree that fails to validate any value. @@ -186,28 +143,15 @@ type anchorInfo struct { // String returns a short description of the schema. func (s *Schema) String() string { - if s.uri != nil { - if u := s.uri.String(); u != "" { - return u - } + if s.ID != "" { + return s.ID } if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { - return fmt.Sprintf("%q, anchor %s", s.base.uri.String(), a) - } - if s.path != "" { - return s.path + return fmt.Sprintf("anchor %s", a) } return "" } -// ResolvedRef returns the Schema to which this schema's $ref keyword -// refers, or nil if it doesn't have a $ref. -// It returns nil if this schema has not been resolved, meaning that -// [Schema.Resolve] was called on it or one of its ancestors. -func (s *Schema) ResolvedRef() *Schema { - return s.resolvedRef -} - func (s *Schema) basicChecks() error { if s.Type != "" && s.Types != nil { return errors.New("both Type and Types are set; at most one should be") diff --git a/jsonschema/validate.go b/jsonschema/validate.go index a04e42bd..3b864107 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -48,12 +48,12 @@ func (rs *Resolved) validateDefaults() error { // We checked for nil schemas in [Schema.Resolve]. assert(s != nil, "nil schema") if s.DynamicRef != "" { - return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", s) + return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", rs.schemaString(s)) } if s.Default != nil { var d any if err := json.Unmarshal(s.Default, &d); err != nil { - return fmt.Errorf("unmarshaling default value of schema %s: %w", s, err) + return fmt.Errorf("unmarshaling default value of schema %s: %w", rs.schemaString(s), err) } if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { return err @@ -74,7 +74,7 @@ type state struct { // validate validates the reflected value of the instance. func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { - defer util.Wrapf(&err, "validating %s", schema) + defer util.Wrapf(&err, "validating %s", st.rs.schemaString(schema)) // Maintain a stack for dynamic schema resolution. st.stack = append(st.stack, schema) // push @@ -90,6 +90,8 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an instance = instance.Elem() } + schemaInfo := st.rs.resolvedInfos[schema] + // type: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1 if schema.Type != "" || schema.Types != nil { gotType, ok := jsonType(instance) @@ -176,7 +178,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an } } - if schema.Pattern != "" && !schema.pattern.MatchString(str) { + if schema.Pattern != "" && !schemaInfo.pattern.MatchString(str) { return fmt.Errorf("pattern: %q does not match regular expression %q", str, schema.Pattern) } } @@ -185,7 +187,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 if schema.Ref != "" { - if err := st.validate(instance, schema.resolvedRef, &anns); err != nil { + if err := st.validate(instance, schemaInfo.resolvedRef, &anns); err != nil { return err } } @@ -193,11 +195,11 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 if schema.DynamicRef != "" { // The ref behaves lexically or dynamically, but not both. - assert((schema.resolvedDynamicRef == nil) != (schema.dynamicRefAnchor == ""), + assert((schemaInfo.resolvedDynamicRef == nil) != (schemaInfo.dynamicRefAnchor == ""), "DynamicRef not resolved properly") - if schema.resolvedDynamicRef != nil { + if schemaInfo.resolvedDynamicRef != nil { // Same as $ref. - if err := st.validate(instance, schema.resolvedDynamicRef, &anns); err != nil { + if err := st.validate(instance, schemaInfo.resolvedDynamicRef, &anns); err != nil { return err } } else { @@ -212,14 +214,15 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. var dynamicSchema *Schema for _, s := range st.stack { - info, ok := s.base.anchors[schema.dynamicRefAnchor] + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[schemaInfo.dynamicRefAnchor] if ok && info.dynamic { dynamicSchema = info.schema break } } if dynamicSchema == nil { - return fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) + return fmt.Errorf("missing dynamic anchor %q", schemaInfo.dynamicRefAnchor) } if err := st.validate(instance, dynamicSchema, &anns); err != nil { return err @@ -417,7 +420,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // If the instance is a struct and an optional property has the zero // value, then we could interpret it as present or missing. Be generous: // assume it's missing, and thus always validates successfully. - if instance.Kind() == reflect.Struct && val.IsZero() && !schema.isRequired[prop] { + if instance.Kind() == reflect.Struct && val.IsZero() && !schemaInfo.isRequired[prop] { continue } if err := st.validate(val, subschema, nil); err != nil { @@ -428,7 +431,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an if len(schema.PatternProperties) > 0 { for prop, val := range properties(instance) { // Check every matching pattern. - for re, schema := range schema.patternProperties { + for re, schema := range schemaInfo.patternProperties { if re.MatchString(prop) { if err := st.validate(val, schema, nil); err != nil { return err @@ -463,7 +466,7 @@ func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *an // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 var min, max int if schema.MinProperties != nil || schema.MaxProperties != nil { - min, max = numPropertiesBounds(instance, schema.isRequired) + min, max = numPropertiesBounds(instance, schemaInfo.isRequired) } if schema.MinProperties != nil { if n, m := max, *schema.MinProperties; n < m { @@ -554,10 +557,11 @@ func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { if schema.DynamicRef == "" { return nil, nil } + info := st.rs.resolvedInfos[schema] // The ref behaves lexically or dynamically, but not both. - assert((schema.resolvedDynamicRef == nil) != (schema.dynamicRefAnchor == ""), + assert((info.resolvedDynamicRef == nil) != (info.dynamicRefAnchor == ""), "DynamicRef not statically resolved properly") - if r := schema.resolvedDynamicRef; r != nil { + if r := info.resolvedDynamicRef; r != nil { // Same as $ref. return r, nil } @@ -571,12 +575,13 @@ func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { // on the stack. // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. for _, s := range st.stack { - info, ok := s.base.anchors[schema.dynamicRefAnchor] + base := st.rs.resolvedInfos[s].base + info, ok := st.rs.resolvedInfos[base].anchors[info.dynamicRefAnchor] if ok && info.dynamic { return info.schema, nil } } - return nil, fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) + return nil, fmt.Errorf("missing dynamic anchor %q", info.dynamicRefAnchor) } // ApplyDefaults modifies an instance by applying the schema's defaults to it. If @@ -608,8 +613,9 @@ func (rs *Resolved) ApplyDefaults(instancep any) error { // Leave this as a potentially recursive helper function, because we'll surely want // to apply defaults on sub-schemas someday. func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { - defer util.Wrapf(&err, "applyDefaults: schema %s, instance %v", schema, instancep) + defer util.Wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) + schemaInfo := st.rs.resolvedInfos[schema] instance := instancep.Elem() if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { if instance.Kind() == reflect.Map { @@ -619,7 +625,7 @@ func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err err } for prop, subschema := range schema.Properties { // Ignore defaults on required properties. (A required property shouldn't have a default.) - if schema.isRequired[prop] { + if schemaInfo.isRequired[prop] { continue } val := property(instance, prop) From a519c182e8118e1d6c5e812ebba918598ccc7935 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 21 Jul 2025 10:54:59 -0400 Subject: [PATCH 031/221] mcp: export Event (#145) Export the Event type in preparation for providing user-definable storage for resumable streams. Also, move event code to a separate file. This PR is just renaming and code motion. --- mcp/event.go | 138 ++++++++++++++++++++++++++++++++++++++++ mcp/event_test.go | 99 +++++++++++++++++++++++++++++ mcp/sse.go | 139 +++-------------------------------------- mcp/sse_test.go | 90 -------------------------- mcp/streamable.go | 15 +++-- mcp/streamable_test.go | 2 +- 6 files changed, 256 insertions(+), 227 deletions(-) create mode 100644 mcp/event.go create mode 100644 mcp/event_test.go diff --git a/mcp/event.go b/mcp/event.go new file mode 100644 index 00000000..4562b811 --- /dev/null +++ b/mcp/event.go @@ -0,0 +1,138 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file is for SSE events. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events. + +package mcp + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "iter" + "net/http" + "strings" +) + +// An Event is a server-sent event. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. +type Event struct { + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field +} + +// Empty reports whether the Event is empty. +func (e Event) Empty() bool { + return e.Name == "" && e.ID == "" && len(e.Data) == 0 +} + +// writeEvent writes the event to w, and flushes. +func writeEvent(w io.Writer, evt Event) (int, error) { + var b bytes.Buffer + if evt.Name != "" { + fmt.Fprintf(&b, "event: %s\n", evt.Name) + } + if evt.ID != "" { + fmt.Fprintf(&b, "id: %s\n", evt.ID) + } + fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) + n, err := w.Write(b.Bytes()) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return n, err +} + +// scanEvents iterates SSE events in the given scanner. The iterated error is +// terminal: if encountered, the stream is corrupt or broken and should no +// longer be used. +// +// TODO(rfindley): consider a different API here that makes failure modes more +// apparent. +func scanEvents(r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewScanner(r) + const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size + scanner.Buffer(nil, maxTokenSize) + + // TODO: investigate proper behavior when events are out of order, or have + // non-standard names. + var ( + eventKey = []byte("event") + idKey = []byte("id") + dataKey = []byte("data") + ) + + return func(yield func(Event, error) bool) { + // iterate event from the wire. + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples + // + // - `key: value` line records. + // - Consecutive `data: ...` fields are joined with newlines. + // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and + // 'data', these are the only three we consider. + // - Lines starting with ":" are ignored. + // - Records are terminated with two consecutive newlines. + var ( + evt Event + dataBuf *bytes.Buffer // if non-nil, preceding field was also data + ) + flushData := func() { + if dataBuf != nil { + evt.Data = dataBuf.Bytes() + dataBuf = nil + } + } + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + flushData() + // \n\n is the record delimiter + if !evt.Empty() && !yield(evt, nil) { + return + } + evt = Event{} + continue + } + before, after, found := bytes.Cut(line, []byte{':'}) + if !found { + yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + return + } + if !bytes.Equal(before, dataKey) { + flushData() + } + switch { + case bytes.Equal(before, eventKey): + evt.Name = strings.TrimSpace(string(after)) + case bytes.Equal(before, idKey): + evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, dataKey): + data := bytes.TrimSpace(after) + if dataBuf != nil { + dataBuf.WriteByte('\n') + dataBuf.Write(data) + } else { + dataBuf = new(bytes.Buffer) + dataBuf.Write(data) + } + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + } + if !yield(Event{}, err) { + return + } + } + flushData() + if !evt.Empty() { + yield(evt, nil) + } + } +} diff --git a/mcp/event_test.go b/mcp/event_test.go new file mode 100644 index 00000000..2e2e3d26 --- /dev/null +++ b/mcp/event_test.go @@ -0,0 +1,99 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "strings" + "testing" +) + +func TestScanEvents(t *testing.T) { + tests := []struct { + name string + input string + want []Event + wantErr string + }{ + { + name: "simple event", + input: "event: message\nid: 1\ndata: hello\n\n", + want: []Event{ + {Name: "message", ID: "1", Data: []byte("hello")}, + }, + }, + { + name: "multiple data lines", + input: "data: line 1\ndata: line 2\n\n", + want: []Event{ + {Data: []byte("line 1\nline 2")}, + }, + }, + { + name: "multiple events", + input: "data: first\n\nevent: second\ndata: second\n\n", + want: []Event{ + {Data: []byte("first")}, + {Name: "second", Data: []byte("second")}, + }, + }, + { + name: "no trailing newline", + input: "data: hello", + want: []Event{ + {Data: []byte("hello")}, + }, + }, + { + name: "malformed line", + input: "invalid line\n\n", + wantErr: "malformed line", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + var got []Event + var err error + for e, err2 := range scanEvents(r) { + if err2 != nil { + err = err2 + break + } + got = append(got, e) + } + + if tt.wantErr != "" { + if err == nil { + t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr) + } + return + } + + if err != nil { + t.Fatalf("scanEvents() returned unexpected error: %v", err) + } + + if len(got) != len(tt.want) { + t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want)) + } + + for i := range got { + if g, w := got[i].Name, tt.want[i].Name; g != w { + t.Errorf("event %d: name = %q, want %q", i, g, w) + } + if g, w := got[i].ID, tt.want[i].ID; g != w { + t.Errorf("event %d: id = %q, want %q", i, g, w) + } + if g, w := string(got[i].Data), string(tt.want[i].Data); g != w { + t.Errorf("event %d: data = %q, want %q", i, g, w) + } + } + }) + } +} diff --git a/mcp/sse.go b/mcp/sse.go index f0d7b34c..4051f45c 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -5,16 +5,12 @@ package mcp import ( - "bufio" "bytes" "context" - "errors" "fmt" "io" - "iter" "net/http" "net/url" - "strings" "sync" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -41,34 +37,6 @@ import ( // - Read reads off a message queue that is pushed to via POST requests. // - Close causes the hanging GET to exit. -// An event is a server-sent event. -type event struct { - name string - id string - data []byte -} - -func (e event) empty() bool { - return e.name == "" && e.id == "" && len(e.data) == 0 -} - -// writeEvent writes the event to w, and flushes. -func writeEvent(w io.Writer, evt event) (int, error) { - var b bytes.Buffer - if evt.name != "" { - fmt.Fprintf(&b, "event: %s\n", evt.name) - } - if evt.id != "" { - fmt.Fprintf(&b, "id: %s\n", evt.id) - } - fmt.Fprintf(&b, "data: %s\n\n", string(evt.data)) - n, err := w.Write(b.Bytes()) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return n, err -} - // SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by // the [2024-11-05 version] of the MCP spec. // @@ -172,9 +140,9 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // See [SSEServerTransport] for more details on the [Connection] implementation. func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { t.mu.Lock() - _, err := writeEvent(t.w, event{ - name: "endpoint", - data: []byte(t.endpoint), + _, err := writeEvent(t.w, Event{ + Name: "endpoint", + Data: []byte(t.endpoint), }) t.mu.Unlock() if err != nil { @@ -300,7 +268,7 @@ func (s sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { return io.EOF } - _, err = writeEvent(s.t.w, event{name: "message", data: data}) + _, err = writeEvent(s.t.w, Event{Name: "message", Data: data}) return err } @@ -372,17 +340,17 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { } msgEndpoint, err := func() (*url.URL, error) { - var evt event + var evt Event for evt, err = range scanEvents(resp.Body) { break } if err != nil { return nil, err } - if evt.name != "endpoint" { - return nil, fmt.Errorf("first event is %q, want %q", evt.name, "endpoint") + if evt.Name != "endpoint" { + return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") } - raw := string(evt.data) + raw := string(evt.Data) return c.sseEndpoint.Parse(raw) }() if err != nil { @@ -408,7 +376,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return } select { - case s.incoming <- evt.data: + case s.incoming <- evt.Data: case <-s.done: return } @@ -418,95 +386,6 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return s, nil } -// scanEvents iterates SSE events in the given scanner. The iterated error is -// terminal: if encountered, the stream is corrupt or broken and should no -// longer be used. -// -// TODO(rfindley): consider a different API here that makes failure modes more -// apparent. -func scanEvents(r io.Reader) iter.Seq2[event, error] { - scanner := bufio.NewScanner(r) - const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size - scanner.Buffer(nil, maxTokenSize) - - // TODO: investigate proper behavior when events are out of order, or have - // non-standard names. - var ( - eventKey = []byte("event") - idKey = []byte("id") - dataKey = []byte("data") - ) - - return func(yield func(event, error) bool) { - // iterate event from the wire. - // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples - // - // - `key: value` line records. - // - Consecutive `data: ...` fields are joined with newlines. - // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and - // 'data', these are the only three we consider. - // - Lines starting with ":" are ignored. - // - Records are terminated with two consecutive newlines. - var ( - evt event - dataBuf *bytes.Buffer // if non-nil, preceding field was also data - ) - flushData := func() { - if dataBuf != nil { - evt.data = dataBuf.Bytes() - dataBuf = nil - } - } - for scanner.Scan() { - line := scanner.Bytes() - if len(line) == 0 { - flushData() - // \n\n is the record delimiter - if !evt.empty() && !yield(evt, nil) { - return - } - evt = event{} - continue - } - before, after, found := bytes.Cut(line, []byte{':'}) - if !found { - yield(event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) - return - } - if !bytes.Equal(before, dataKey) { - flushData() - } - switch { - case bytes.Equal(before, eventKey): - evt.name = strings.TrimSpace(string(after)) - case bytes.Equal(before, idKey): - evt.id = strings.TrimSpace(string(after)) - case bytes.Equal(before, dataKey): - data := bytes.TrimSpace(after) - if dataBuf != nil { - dataBuf.WriteByte('\n') - dataBuf.Write(data) - } else { - dataBuf = new(bytes.Buffer) - dataBuf.Write(data) - } - } - } - if err := scanner.Err(); err != nil { - if errors.Is(err, bufio.ErrTooLong) { - err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) - } - if !yield(event{}, err) { - return - } - } - flushData() - if !evt.empty() { - yield(evt, nil) - } - } -} - // An sseClientConn is a logical jsonrpc2 connection that implements the client // half of the SSE protocol: // - Writes are POSTS to the session endpoint. diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 846d68c0..b4e8ebad 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -9,7 +9,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "strings" "sync/atomic" "testing" @@ -90,95 +89,6 @@ func TestSSEServer(t *testing.T) { } } -func TestScanEvents(t *testing.T) { - tests := []struct { - name string - input string - want []event - wantErr string - }{ - { - name: "simple event", - input: "event: message\nid: 1\ndata: hello\n\n", - want: []event{ - {name: "message", id: "1", data: []byte("hello")}, - }, - }, - { - name: "multiple data lines", - input: "data: line 1\ndata: line 2\n\n", - want: []event{ - {data: []byte("line 1\nline 2")}, - }, - }, - { - name: "multiple events", - input: "data: first\n\nevent: second\ndata: second\n\n", - want: []event{ - {data: []byte("first")}, - {name: "second", data: []byte("second")}, - }, - }, - { - name: "no trailing newline", - input: "data: hello", - want: []event{ - {data: []byte("hello")}, - }, - }, - { - name: "malformed line", - input: "invalid line\n\n", - wantErr: "malformed line", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := strings.NewReader(tt.input) - var got []event - var err error - for e, err2 := range scanEvents(r) { - if err2 != nil { - err = err2 - break - } - got = append(got, e) - } - - if tt.wantErr != "" { - if err == nil { - t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr) - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr) - } - return - } - - if err != nil { - t.Fatalf("scanEvents() returned unexpected error: %v", err) - } - - if len(got) != len(tt.want) { - t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want)) - } - - for i := range got { - if g, w := got[i].name, tt.want[i].name; g != w { - t.Errorf("event %d: name = %q, want %q", i, g, w) - } - if g, w := got[i].id, tt.want[i].id; g != w { - t.Errorf("event %d: id = %q, want %q", i, g, w) - } - if g, w := string(got[i].data), string(tt.want[i].data); g != w { - t.Errorf("event %d: data = %q, want %q", i, g, w) - } - } - }) - } -} - // roundTripperFunc is a helper to create a custom RoundTripper type roundTripperFunc func(*http.Request) (*http.Response, error) diff --git a/mcp/streamable.go b/mcp/streamable.go index d371c873..ef740d37 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -149,6 +149,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // NewStreamableServerTransport returns a new [StreamableServerTransport] with // the given session ID. +// The session ID must be globally unique, that is, different from any other +// session ID anywhere, past and future. (We recommend using a crypto random number +// generator to produce one, as in [crypto/rand.Text].) // // A StreamableServerTransport implements the server-side of the streamable // transport. @@ -246,7 +249,7 @@ type streamID int64 // a streamableMsg is an SSE event with an index into its logical stream. type streamableMsg struct { idx int - event event + event Event } // Connect implements the [Transport] interface. @@ -549,10 +552,10 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa idx := len(t.outgoingMessages[forConn]) t.outgoingMessages[forConn] = append(t.outgoingMessages[forConn], &streamableMsg{ idx: idx, - event: event{ - name: "message", - id: formatEventID(forConn, idx), - data: data, + event: Event{ + Name: "message", + ID: formatEventID(forConn, idx), + Data: data, }, }) if replyTo.IsValid() { @@ -768,7 +771,7 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) { select { case <-s.done: return - case s.incoming <- evt.data: + case s.incoming <- evt.Data: } } }() diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e0e85a1e..05bf59e7 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -520,7 +520,7 @@ func streamingRequest(ctx context.Context, serverURL, sessionID, method string, } // TODO(rfindley): do we need to check evt.name? // Does the MCP spec say anything about this? - msg, err := jsonrpc2.DecodeMessage(evt.data) + msg, err := jsonrpc2.DecodeMessage(evt.Data) if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) } From c6fabbb3d3ef3fa0e3f06175da2bc519c569af9c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 21 Jul 2025 11:18:04 -0400 Subject: [PATCH 032/221] jsonschema: inference ignores invalid types (#147) Add ForLax[T], which ignores invalid types in schema inference instead of returning an error. This allows additional customization of a schema after inference does what it can. For example, an interface type where all the possible implementations are known can be described with "oneof". For #136. --- jsonschema/infer.go | 46 ++++++++- jsonschema/infer_test.go | 206 +++++++++++++++++++++++---------------- 2 files changed, 162 insertions(+), 90 deletions(-) diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 654e6197..9ff0ddd5 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -47,7 +47,7 @@ import ( func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen) + s, err := forType(reflect.TypeFor[T](), seen, false) if err != nil { var z T return nil, fmt.Errorf("For[%T](): %w", z, err) @@ -55,7 +55,22 @@ func For[T any]() (*Schema, error) { return s, nil } -func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { +// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of +// returning an error. That allows callers to adjust the resulting schema using custom knowledge. +// For example, an interface type where all the possible implementations are known +// can be described with "oneof". +func ForLax[T any]() (*Schema, error) { + // TODO: consider skipping incompatible fields, instead of failing. + seen := make(map[reflect.Type]bool) + s, err := forType(reflect.TypeFor[T](), seen, true) + if err != nil { + var z T + return nil, fmt.Errorf("ForLax[%T](): %w", z, err) + } + return s, nil +} + +func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -96,20 +111,33 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { case reflect.Map: if t.Key().Kind() != reflect.String { + if lax { + return nil, nil // ignore + } return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } + if t.Key().Kind() != reflect.String { + } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem(), seen) + s.AdditionalProperties, err = forType(t.Elem(), seen, lax) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } + if lax && s.AdditionalProperties == nil { + // Ignore if the element type is invalid. + return nil, nil + } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem(), seen) + s.Items, err = forType(t.Elem(), seen, lax) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } + if lax && s.Items == nil { + // Ignore if the element type is invalid. + return nil, nil + } if t.Kind() == reflect.Array { s.MinItems = Ptr(t.Len()) s.MaxItems = Ptr(t.Len()) @@ -132,10 +160,14 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - fs, err := forType(field.Type, seen) + fs, err := forType(field.Type, seen, lax) if err != nil { return nil, err } + if lax && fs == nil { + // Skip fields of invalid type. + continue + } if tag, ok := field.Tag.Lookup("jsonschema"); ok { if tag == "" { return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) @@ -152,6 +184,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { } default: + if lax { + // Ignore. + return nil, nil + } return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) } if allowNull && s.Type != "" { diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 106e5375..8c8feec0 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -13,8 +13,14 @@ import ( "github.com/modelcontextprotocol/go-sdk/jsonschema" ) -func forType[T any]() *jsonschema.Schema { - s, err := jsonschema.For[T]() +func forType[T any](lax bool) *jsonschema.Schema { + var s *jsonschema.Schema + var err error + if lax { + s, err = jsonschema.ForLax[T]() + } else { + s, err = jsonschema.For[T]() + } if err != nil { panic(err) } @@ -28,104 +34,134 @@ func TestFor(t *testing.T) { B int `jsonschema:"bdesc"` } - tests := []struct { + type test struct { name string got *jsonschema.Schema want *jsonschema.Schema - }{ - {"string", forType[string](), &schema{Type: "string"}}, - {"int", forType[int](), &schema{Type: "integer"}}, - {"int16", forType[int16](), &schema{Type: "integer"}}, - {"uint32", forType[int16](), &schema{Type: "integer"}}, - {"float64", forType[float64](), &schema{Type: "number"}}, - {"bool", forType[bool](), &schema{Type: "boolean"}}, - {"intmap", forType[map[string]int](), &schema{ - Type: "object", - AdditionalProperties: &schema{Type: "integer"}, - }}, - {"anymap", forType[map[string]any](), &schema{ - Type: "object", - AdditionalProperties: &schema{}, - }}, - { - "struct", - forType[struct { - F int `json:"f" jsonschema:"fdesc"` - G []float64 - P *bool `jsonschema:"pdesc"` - Skip string `json:"-"` - NoSkip string `json:",omitempty"` - unexported float64 - unexported2 int `json:"No"` - }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "f": {Type: "integer", Description: "fdesc"}, - "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, - "NoSkip": {Type: "string"}, + } + + tests := func(lax bool) []test { + return []test{ + {"string", forType[string](lax), &schema{Type: "string"}}, + {"int", forType[int](lax), &schema{Type: "integer"}}, + {"int16", forType[int16](lax), &schema{Type: "integer"}}, + {"uint32", forType[int16](lax), &schema{Type: "integer"}}, + {"float64", forType[float64](lax), &schema{Type: "number"}}, + {"bool", forType[bool](lax), &schema{Type: "boolean"}}, + {"intmap", forType[map[string]int](lax), &schema{ + Type: "object", + AdditionalProperties: &schema{Type: "integer"}, + }}, + {"anymap", forType[map[string]any](lax), &schema{ + Type: "object", + AdditionalProperties: &schema{}, + }}, + { + "struct", + forType[struct { + F int `json:"f" jsonschema:"fdesc"` + G []float64 + P *bool `jsonschema:"pdesc"` + Skip string `json:"-"` + NoSkip string `json:",omitempty"` + unexported float64 + unexported2 int `json:"No"` + }](lax), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "f": {Type: "integer", Description: "fdesc"}, + "G": {Type: "array", Items: &schema{Type: "number"}}, + "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, + "NoSkip": {Type: "string"}, + }, + Required: []string{"f", "G", "P"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"f", "G", "P"}, - AdditionalProperties: falseSchema(), }, - }, - { - "no sharing", - forType[struct{ X, Y int }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "X": {Type: "integer"}, - "Y": {Type: "integer"}, + { + "no sharing", + forType[struct{ X, Y int }](lax), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "X": {Type: "integer"}, + "Y": {Type: "integer"}, + }, + Required: []string{"X", "Y"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"X", "Y"}, - AdditionalProperties: falseSchema(), }, - }, - { - "nested and embedded", - forType[struct { - A S - S - }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "A": { - Type: "object", - Properties: map[string]*schema{ - "B": {Type: "integer", Description: "bdesc"}, + { + "nested and embedded", + forType[struct { + A S + S + }](lax), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "A": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"B"}, - AdditionalProperties: falseSchema(), - }, - "S": { - Type: "object", - Properties: map[string]*schema{ - "B": {Type: "integer", Description: "bdesc"}, + "S": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"B"}, - AdditionalProperties: falseSchema(), }, + Required: []string{"A", "S"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"A", "S"}, - AdditionalProperties: falseSchema(), }, - }, + } } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if diff := cmp.Diff(test.want, test.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ForType mismatch (-want +got):\n%s", diff) - } - // These schemas should all resolve. - if _, err := test.got.Resolve(nil); err != nil { - t.Fatalf("Resolving: %v", err) - } - }) + run := func(t *testing.T, tt test) { + if diff := cmp.Diff(tt.want, tt.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ForType mismatch (-want +got):\n%s", diff) + } + // These schemas should all resolve. + if _, err := tt.got.Resolve(nil); err != nil { + t.Fatalf("Resolving: %v", err) + } } + + t.Run("strict", func(t *testing.T) { + for _, test := range tests(false) { + t.Run(test.name, func(t *testing.T) { run(t, test) }) + } + }) + + laxTests := append(tests(true), test{ + "ignore", + forType[struct { + A int + B map[int]int + C func() + }](true), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "A": {Type: "integer"}, + }, + Required: []string{"A"}, + AdditionalProperties: falseSchema(), + }, + }) + t.Run("lax", func(t *testing.T) { + for _, test := range laxTests { + t.Run(test.name, func(t *testing.T) { run(t, test) }) + } + }) } func forErr[T any]() error { From 9ddad03891c508e2dc9a2aaa35bf83f29e77de30 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 21 Jul 2025 11:40:32 -0400 Subject: [PATCH 033/221] mcp: dedup wire type for image and audio (#149) Factor out the wire type into a top-level struct. --- mcp/content.go | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/mcp/content.go b/mcp/content.go index 04888dec..8bf75f0f 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -63,13 +63,7 @@ func (c *ImageContent) MarshalJSON() ([]byte, error) { if data == nil { data = []byte{} } - wire := struct { - Type string `json:"type"` - MIMEType string `json:"mimeType"` - Data []byte `json:"data"` - Meta Meta `json:"_meta,omitempty"` - Annotations *Annotations `json:"annotations,omitempty"` - }{ + wire := imageAudioWire{ Type: "image", MIMEType: c.MIMEType, Data: data, @@ -100,13 +94,7 @@ func (c AudioContent) MarshalJSON() ([]byte, error) { if data == nil { data = []byte{} } - wire := struct { - Type string `json:"type"` - MIMEType string `json:"mimeType"` - Data []byte `json:"data"` - Meta Meta `json:"_meta,omitempty"` - Annotations *Annotations `json:"annotations,omitempty"` - }{ + wire := imageAudioWire{ Type: "audio", MIMEType: c.MIMEType, Data: data, @@ -123,6 +111,15 @@ func (c *AudioContent) fromWire(wire *wireContent) { c.Annotations = wire.Annotations } +// Custom wire format to ensure required fields are always included, even when empty. +type imageAudioWire struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + // ResourceLink is a link to a resource type ResourceLink struct { URI string From 1465442b250629a7108d05f168c7f06faad707bc Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 22 Jul 2025 10:21:05 -0400 Subject: [PATCH 034/221] mcp: add EventStore (#152) Introduce an EventStore interface to store events for resumable streams. Provide an in-memory implementation. Still to do: connect to streaming transports. For #10 --- mcp/event.go | 290 +++++++++++++++++++++++++++++++++++++++++ mcp/event_test.go | 167 ++++++++++++++++++++++++ mcp/streamable.go | 32 ++--- mcp/streamable_test.go | 2 +- 4 files changed, 474 insertions(+), 17 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 4562b811..fbbe1941 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -10,14 +10,22 @@ package mcp import ( "bufio" "bytes" + "context" "errors" "fmt" "io" "iter" + "maps" "net/http" + "slices" "strings" + "sync" ) +// If true, MemoryEventStore will do frequent validation to check invariants, slowing it down. +// Remove when we're confident in the code. +const validateMemoryEventStore = true + // An Event is a server-sent event. // See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. type Event struct { @@ -136,3 +144,285 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { } } } + +// An EventStore tracks data for SSE streams. +// A single EventStore suffices for all sessions, since session IDs are +// globally unique. So one EventStore can be created per process, for +// all Servers in the process. +// Such a store is able to bound resource usage for the entire process. +// +// All of an EventStore's methods must be safe for use by multiple goroutines. +type EventStore interface { + // AppendEvent appends data for an outgoing event to given stream, which is part of the + // given session. It returns the index of the event in the stream, suitable for constructing + // an event ID to send to the client. + AppendEvent(_ context.Context, sessionID string, _ StreamID, data []byte) (int, error) + + // After returns an iterator over the data for the given session and stream, beginning + // just after the given index. + // Once the iterator yields a non-nil error, it will stop. + // After's iterator must return an error immediately if any data after index was + // dropped; it must not return partial results. + After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] + + // StreamClosed informs the store that the given stream is finished. + // A store cannot rely on this method being called for cleanup. It should institute + // additional mechanisms, such as timeouts, to reclaim storage. + StreamClosed(_ context.Context, sessionID string, streamID StreamID) error + + // SessionClosed informs the store that the given session is finished, along + // with all of its streams. + // A store cannot rely on this method being called for cleanup. It should institute + // additional mechanisms, such as timeouts, to reclaim storage. + SessionClosed(_ context.Context, sessionID string) error +} + +// A dataList is a list of []byte. +// The zero dataList is ready to use. +type dataList struct { + size int // total size of data bytes + first int // the stream index of the first element in data + data [][]byte +} + +func (dl *dataList) appendData(d []byte) { + // If we allowed empty data, we would consume memory without incrementing the size. + // We could of course account for that, but we keep it simple and assume there is no + // empty data. + if len(d) == 0 { + panic("empty data item") + } + dl.data = append(dl.data, d) + dl.size += len(d) +} + +// removeFirst removes the first data item in dl, returning the size of the item. +// It panics if dl is empty. +func (dl *dataList) removeFirst() int { + if len(dl.data) == 0 { + panic("empty dataList") + } + r := len(dl.data[0]) + dl.size -= r + dl.data[0] = nil // help GC + dl.data = dl.data[1:] + dl.first++ + return r +} + +// lastIndex returns the index of the last data item in dl. +// It panics if there are none. +func (dl *dataList) lastIndex() int { + if len(dl.data) == 0 { + panic("empty dataList") + } + return dl.first + len(dl.data) - 1 +} + +// A MemoryEventStore is an [EventStore] backed by memory. +type MemoryEventStore struct { + mu sync.Mutex + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[StreamID]*dataList // session ID -> stream ID -> *dataList +} + +// MemoryEventStoreOptions are options for a [MemoryEventStore]. +type MemoryEventStoreOptions struct{} + +// MaxBytes returns the maximum number of bytes that the store will retain before +// purging data. +func (s *MemoryEventStore) MaxBytes() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.maxBytes +} + +// SetMaxBytes sets the maximum number of bytes the store will retain before purging +// data. The argument must not be negative. If it is zero, a suitable default will be used. +// SetMaxBytes can be called at any time. The size of the store will be adjusted +// immediately. +func (s *MemoryEventStore) SetMaxBytes(n int) { + s.mu.Lock() + defer s.mu.Unlock() + switch { + case n < 0: + panic("negative argument") + case n == 0: + s.maxBytes = defaultMaxBytes + default: + s.maxBytes = n + } + s.purge() +} + +const defaultMaxBytes = 10 << 20 // 10 MiB + +// NewMemoryEventStore creates a [MemoryEventStore] with the default value +// for MaxBytes. +func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { + return &MemoryEventStore{ + maxBytes: defaultMaxBytes, + store: make(map[string]map[StreamID]*dataList), + } +} + +// AppendEvent implements [EventStore.AppendEvent] by recording data +// in memory. +func (s *MemoryEventStore) AppendEvent(_ context.Context, sessionID string, streamID StreamID, data []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + streamMap, ok := s.store[sessionID] + if !ok { + streamMap = make(map[StreamID]*dataList) + s.store[sessionID] = streamMap + } + dl, ok := streamMap[streamID] + if !ok { + dl = &dataList{} + streamMap[streamID] = dl + } + // Purge before adding, so at least the current data item will be present. + // (That could result in nBytes > maxBytes, but we'll live with that.) + s.purge() + dl.appendData(data) + s.nBytes += len(data) + return dl.lastIndex(), nil +} + +// After implements [EventStore.After]. +func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] { + // Return the data items to yield. + // We must copy, because dataList.removeFirst nils out slice elements. + copyData := func() ([][]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + streamMap, ok := s.store[sessionID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID) + } + dl, ok := streamMap[streamID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) + } + if dl.first > index { + return nil, fmt.Errorf("MemoryEventStore.After: data purged at index %d, stream ID %v, session %q", index, streamID, sessionID) + } + return slices.Clone(dl.data[index-dl.first:]), nil + } + + return func(yield func([]byte, error) bool) { + ds, err := copyData() + if err != nil { + yield(nil, err) + return + } + for _, d := range ds { + if !yield(d, nil) { + return + } + } + } +} + +// StreamClosed implements [EventStore.StreamClosed]. +func (s *MemoryEventStore) StreamClosed(_ context.Context, sessionID string, streamID StreamID) error { + if sessionID == "" { + panic("empty sessionID") + } + + s.mu.Lock() + defer s.mu.Unlock() + + sm := s.store[sessionID] + dl := sm[streamID] + s.nBytes -= dl.size + delete(sm, streamID) + if len(sm) == 0 { + delete(s.store, sessionID) + } + s.validate() + return nil +} + +// SessionClosed implements [EventStore.SessionClosed]. +func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, dl := range s.store[sessionID] { + s.nBytes -= dl.size + } + delete(s.store, sessionID) + s.validate() + return nil +} + +// purge removes data until no more than s.maxBytes bytes are in use. +// It must be called with s.mu held. +func (s *MemoryEventStore) purge() { + // Remove the first element of every dataList until below the max. + for s.nBytes > s.maxBytes { + changed := false + for _, sm := range s.store { + for _, dl := range sm { + if dl.size > 0 { + r := dl.removeFirst() + if r > 0 { + changed = true + s.nBytes -= r + } + } + } + } + if !changed { + panic("no progress during purge") + } + } + s.validate() +} + +// validate checks that the store's data structures are valid. +// It must be called with s.mu held. +func (s *MemoryEventStore) validate() { + if !validateMemoryEventStore { + return + } + // Check that we're accounting for the size correctly. + n := 0 + for _, sm := range s.store { + for _, dl := range sm { + for _, d := range dl.data { + n += len(d) + } + } + } + if n != s.nBytes { + panic("sizes don't add up") + } +} + +// debugString returns a string containing the state of s. +// Used in tests. +func (s *MemoryEventStore) debugString() string { + s.mu.Lock() + defer s.mu.Unlock() + var b strings.Builder + for i, sess := range slices.Sorted(maps.Keys(s.store)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + sm := s.store[sess] + for i, sid := range slices.Sorted(maps.Keys(sm)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + dl := sm[sid] + fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first) + for _, d := range dl.data { + fmt.Fprintf(&b, " %s", d) + } + } + } + return b.String() +} diff --git a/mcp/event_test.go b/mcp/event_test.go index 2e2e3d26..9b01555f 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -5,6 +5,9 @@ package mcp import ( + "context" + "fmt" + "slices" "strings" "testing" ) @@ -97,3 +100,167 @@ func TestScanEvents(t *testing.T) { }) } } + +func TestMemoryEventStoreState(t *testing.T) { + ctx := context.Background() + + appendEvent := func(s *MemoryEventStore, sess string, str StreamID, data string) { + if _, err := s.AppendEvent(ctx, sess, str, []byte(data)); err != nil { + t.Fatal(err) + } + } + + for _, tt := range []struct { + name string + actions func(*MemoryEventStore) + want string // output of debugString + wantSize int // value of nBytes + }{ + { + "appends", + func(s *MemoryEventStore) { + appendEvent(s, "S1", 1, "d1") + appendEvent(s, "S1", 2, "d2") + appendEvent(s, "S1", 1, "d3") + appendEvent(s, "S2", 8, "d4") + }, + "S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4", + 8, + }, + { + "stream close", + func(s *MemoryEventStore) { + appendEvent(s, "S1", 1, "d1") + appendEvent(s, "S1", 2, "d2") + appendEvent(s, "S1", 1, "d3") + appendEvent(s, "S2", 8, "d4") + s.StreamClosed(ctx, "S1", 1) + }, + "S1 2 first=0 d2; S2 8 first=0 d4", + 4, + }, + { + "session close", + func(s *MemoryEventStore) { + appendEvent(s, "S1", 1, "d1") + appendEvent(s, "S1", 2, "d2") + appendEvent(s, "S1", 1, "d3") + appendEvent(s, "S2", 8, "d4") + s.SessionClosed(ctx, "S1") + }, + "S2 8 first=0 d4", + 2, + }, + { + "purge", + func(s *MemoryEventStore) { + appendEvent(s, "S1", 1, "d1") + appendEvent(s, "S1", 2, "d2") + appendEvent(s, "S1", 1, "d3") + appendEvent(s, "S2", 8, "d4") + // We are using 8 bytes (d1,d2, d3, d4). + // To purge 6, we remove the first of each stream, leaving only d3. + s.SetMaxBytes(2) + }, + // The other streams remain, because we may add to them. + "S1 1 first=1 d3; S1 2 first=1; S2 8 first=1", + 2, + }, + { + "purge append", + func(s *MemoryEventStore) { + appendEvent(s, "S1", 1, "d1") + appendEvent(s, "S1", 2, "d2") + appendEvent(s, "S1", 1, "d3") + appendEvent(s, "S2", 8, "d4") + s.SetMaxBytes(2) + // Up to here, identical to the "purge" case. + // Each of these additions will result in a purge. + appendEvent(s, "S1", 2, "d5") // remove d3 + appendEvent(s, "S1", 2, "d6") // remove d5 + }, + "S1 1 first=2; S1 2 first=2 d6; S2 8 first=1", + 2, + }, + { + "purge resize append", + func(s *MemoryEventStore) { + appendEvent(s, "S1", 1, "d1") + appendEvent(s, "S1", 2, "d2") + appendEvent(s, "S1", 1, "d3") + appendEvent(s, "S2", 8, "d4") + s.SetMaxBytes(2) + // Up to here, identical to the "purge" case. + s.SetMaxBytes(6) // make room + appendEvent(s, "S1", 2, "d5") + appendEvent(s, "S1", 2, "d6") + }, + // The other streams remain, because we may add to them. + "S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1", + 6, + }, + } { + t.Run(tt.name, func(t *testing.T) { + s := NewMemoryEventStore(nil) + tt.actions(s) + got := s.debugString() + if got != tt.want { + t.Errorf("\ngot %s\nwant %s", got, tt.want) + } + if g, w := s.nBytes, tt.wantSize; g != w { + t.Errorf("got size %d, want %d", g, w) + } + }) + } +} + +func TestMemoryEventStoreAfter(t *testing.T) { + ctx := context.Background() + s := NewMemoryEventStore(nil) + s.SetMaxBytes(4) + s.AppendEvent(ctx, "S1", 1, []byte("d1")) + s.AppendEvent(ctx, "S1", 1, []byte("d2")) + s.AppendEvent(ctx, "S1", 1, []byte("d3")) + s.AppendEvent(ctx, "S1", 2, []byte("d4")) // will purge d1 + want := "S1 1 first=1 d2 d3; S1 2 first=0 d4" + if got := s.debugString(); got != want { + t.Fatalf("got state %q, want %q", got, want) + } + + for _, tt := range []struct { + sessionID string + streamID StreamID + index int + want []string + wantErr string // if non-empty, error should contain this string + }{ + {"S1", 1, 0, nil, "purge"}, + {"S1", 1, 1, []string{"d2", "d3"}, ""}, + {"S1", 1, 2, []string{"d3"}, ""}, + {"S1", 2, 0, []string{"d4"}, ""}, + {"S1", 3, 0, nil, "unknown stream ID"}, + {"S2", 0, 0, nil, "unknown session ID"}, + } { + t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) { + var got []string + for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) { + if err != nil { + if tt.wantErr == "" { + t.Fatalf("unexpected error %q", err) + } else if g := err.Error(); !strings.Contains(g, tt.wantErr) { + t.Fatalf("got error %q, want it to contain %q", g, tt.wantErr) + } else { + return + } + } + got = append(got, string(d)) + } + if tt.wantErr != "" { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !slices.Equal(got, tt.want) { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index ef740d37..728da851 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -163,10 +163,10 @@ func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { id: sessionID, incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), - outgoingMessages: make(map[streamID][]*streamableMsg), - signals: make(map[streamID]chan struct{}), - requestStreams: make(map[jsonrpc.ID]streamID), - streamRequests: make(map[streamID]map[jsonrpc.ID]struct{}), + outgoingMessages: make(map[StreamID][]*streamableMsg), + signals: make(map[StreamID]chan struct{}), + requestStreams: make(map[jsonrpc.ID]StreamID), + streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}), } } @@ -212,7 +212,7 @@ type StreamableServerTransport struct { // // TODO(rfindley): garbage collect this data. For now, we save all outgoingMessages // messages for the lifespan of the transport. - outgoingMessages map[streamID][]*streamableMsg + outgoingMessages map[StreamID][]*streamableMsg // signals maps a logical stream ID to a 1-buffered channel, owned by an // incoming HTTP request, that signals that there are messages available to @@ -223,14 +223,14 @@ type StreamableServerTransport struct { // // Lifecycle: signals persists for the duration of an HTTP POST or GET // request for the given streamID. - signals map[streamID]chan struct{} + signals map[StreamID]chan struct{} // requestStreams maps incoming requests to their logical stream ID. // // Lifecycle: requestStreams persists for the duration of the session. // // TODO(rfindley): clean up once requests are handled. - requestStreams map[jsonrpc.ID]streamID + requestStreams map[jsonrpc.ID]StreamID // streamRequests tracks the set of unanswered incoming RPCs for each logical // stream. @@ -241,10 +241,10 @@ type StreamableServerTransport struct { // Lifecycle: streamRequests values persist as until the requests have been // replied to by the server. Notably, NOT until they are sent to an HTTP // response, as delivery is not guaranteed. - streamRequests map[streamID]map[jsonrpc.ID]struct{} + streamRequests map[StreamID]map[jsonrpc.ID]struct{} } -type streamID int64 +type StreamID int64 // a streamableMsg is an SSE event with an index into its logical stream. type streamableMsg struct { @@ -298,7 +298,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id, nextIdx := streamID(0), 0 + id, nextIdx := StreamID(0), 0 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool @@ -352,7 +352,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R } // Update accounting for this request. - id := streamID(t.nextStreamID.Add(1)) + id := StreamID(t.nextStreamID.Add(1)) signal := make(chan struct{}, 1) t.mu.Lock() if len(requests) > 0 { @@ -376,7 +376,7 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R t.streamResponse(w, req, id, 0, signal) } -func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id streamID, nextIndex int, signal chan struct{}) { +func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, nextIndex int, signal chan struct{}) { defer func() { t.mu.Lock() delete(t.signals, id) @@ -463,7 +463,7 @@ stream: // streamID and message index idx. // // See also [parseEventID]. -func formatEventID(sid streamID, idx int) string { +func formatEventID(sid StreamID, idx int) string { return fmt.Sprintf("%d_%d", sid, idx) } @@ -471,7 +471,7 @@ func formatEventID(sid streamID, idx int) string { // index. // // See also [formatEventID]. -func parseEventID(eventID string) (sid streamID, idx int, ok bool) { +func parseEventID(eventID string) (sid StreamID, idx int, ok bool) { parts := strings.Split(eventID, "_") if len(parts) != 2 { return 0, 0, false @@ -484,7 +484,7 @@ func parseEventID(eventID string) (sid streamID, idx int, ok bool) { if err != nil || idx < 0 { return 0, 0, false } - return streamID(stream), idx, true + return StreamID(stream), idx, true } // Read implements the [Connection] interface. @@ -523,7 +523,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa // // For messages sent outside of a request context, this is the default // connection 0. - var forConn streamID + var forConn StreamID if forRequest.IsValid() { t.mu.Lock() forConn = t.requestStreams[forRequest] diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 05bf59e7..87664da9 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -555,7 +555,7 @@ func mustMarshal(t *testing.T, v any) json.RawMessage { func TestEventID(t *testing.T) { tests := []struct { - sid streamID + sid StreamID idx int }{ {0, 0}, From a911cd0ffde0f29abe22d0eef78ba333e269809b Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Tue, 22 Jul 2025 15:58:29 -0400 Subject: [PATCH 035/221] mcp/streamable: add resumability for the Streamable transport (#133) This CL implements a retry mechanism to resume SSE streams to recover from network failures. For #10 --- mcp/streamable.go | 219 ++++++++++++++++++++++++++++++++++++----- mcp/streamable_test.go | 125 +++++++++++++++++++++++ 2 files changed, 318 insertions(+), 26 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 728da851..f7c6ed63 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -9,11 +9,14 @@ import ( "context" "fmt" "io" + "math" + "math/rand/v2" "net/http" "strconv" "strings" "sync" "sync/atomic" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -597,12 +600,39 @@ type StreamableClientTransport struct { opts StreamableClientTransportOptions } +// StreamableReconnectOptions defines parameters for client reconnect attempts. +type StreamableReconnectOptions struct { + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // A value of 0 or less means never retry. + MaxRetries int + + // growFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. + // It must be 1.0 or greater if MaxRetries is greater than 0. + growFactor float64 + + // initialDelay is the base delay for the first reconnect attempt. + initialDelay time.Duration + + // maxDelay caps the backoff delay, preventing it from growing indefinitely. + maxDelay time.Duration +} + +// DefaultReconnectOptions provides sensible defaults for reconnect logic. +var DefaultReconnectOptions = &StreamableReconnectOptions{ + MaxRetries: 5, + growFactor: 1.5, + initialDelay: 1 * time.Second, + maxDelay: 30 * time.Second, +} + // StreamableClientTransportOptions provides options for the // [NewStreamableClientTransport] constructor. type StreamableClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. - HTTPClient *http.Client + HTTPClient *http.Client + ReconnectOptions *StreamableReconnectOptions } // NewStreamableClientTransport returns a new client transport that connects to @@ -628,22 +658,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er if client == nil { client = http.DefaultClient } - return &streamableClientConn{ - url: t.url, - client: client, - incoming: make(chan []byte, 100), - done: make(chan struct{}), - }, nil + reconnOpts := t.opts.ReconnectOptions + if reconnOpts == nil { + reconnOpts = DefaultReconnectOptions + } + // Create a new cancellable context that will manage the connection's lifecycle. + // This is crucial for cleanly shutting down the background SSE listener by + // cancelling its blocking network operations, which prevents hangs on exit. + connCtx, cancel := context.WithCancel(context.Background()) + conn := &streamableClientConn{ + url: t.url, + client: client, + incoming: make(chan []byte, 100), + done: make(chan struct{}), + ReconnectOptions: reconnOpts, + ctx: connCtx, + cancel: cancel, + } + return conn, nil } type streamableClientConn struct { - url string - client *http.Client - incoming chan []byte - done chan struct{} + url string + client *http.Client + incoming chan []byte + done chan struct{} + ReconnectOptions *StreamableReconnectOptions closeOnce sync.Once closeErr error + ctx context.Context + cancel context.CancelFunc mu sync.Mutex protocolVersion string @@ -665,6 +710,12 @@ func (c *streamableClientConn) SessionID() string { // Read implements the [Connection] interface. func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + s.mu.Lock() + err := s.err + s.mu.Unlock() + if err != nil { + return nil, err + } select { case <-ctx.Done(): return nil, ctx.Err() @@ -745,6 +796,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string sessionID = resp.Header.Get(sessionIDHeader) switch ct := resp.Header.Get("Content-Type"); ct { case "text/event-stream": + // Section 2.1: The SSE stream is initiated after a POST. go s.handleSSE(resp) case "application/json": // TODO: read the body and send to s.incoming (in a select that also receives from s.done). @@ -757,34 +809,115 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return sessionID, nil } -func (s *streamableClientConn) handleSSE(resp *http.Response) { +// handleSSE manages the entire lifecycle of an SSE connection. It processes +// an incoming Server-Sent Events stream and automatically handles reconnection +// logic if the stream breaks. +func (s *streamableClientConn) handleSSE(initialResp *http.Response) { + resp := initialResp + var lastEventID string + + for { + eventID, clientClosed := s.processStream(resp) + lastEventID = eventID + + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + + // The stream was interrupted or ended by the server. Attempt to reconnect. + newResp, err := s.reconnect(lastEventID) + if err != nil { + // All reconnection attempts failed. Set the final error, close the + // connection, and exit the goroutine. + s.mu.Lock() + s.err = err + s.mu.Unlock() + s.Close() + return + } + + // Reconnection was successful. Continue the loop with the new response. + resp = newResp + } +} + +// processStream reads from a single response body, sending events to the +// incoming channel. It returns the ID of the last processed event, any error +// that occurred, and a flag indicating if the connection was closed by the client. +func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { defer resp.Body.Close() - done := make(chan struct{}) - go func() { - defer close(done) - for evt, err := range scanEvents(resp.Body) { + for evt, err := range scanEvents(resp.Body) { + if err != nil { + return lastEventID, false + } + + if evt.ID != "" { + lastEventID = evt.ID + } + + select { + case s.incoming <- evt.Data: + case <-s.done: + // The connection was closed by the client; exit gracefully. + return lastEventID, true + } + } + + // The loop finished without an error, indicating the server closed the stream. + // We'll attempt to reconnect, so this is not a client-side close. + return lastEventID, false +} + +// reconnect handles the logic of retrying a connection with an exponential +// backoff strategy. It returns a new, valid HTTP response if successful, or +// an error if all retries are exhausted. +func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { + var finalErr error + + for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ { + select { + case <-s.done: + return nil, fmt.Errorf("connection closed by client during reconnect") + case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)): + resp, err := s.establishSSE(lastEventID) if err != nil { - // TODO: surface this error; possibly break the stream - return + finalErr = err // Store the error and try again. + continue } - select { - case <-s.done: - return - case s.incoming <- evt.Data: + + if !isResumable(resp) { + // The server indicated we should not continue. + resp.Body.Close() + return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status) } + + return resp, nil } - }() + } + // If the loop completes, all retries have failed. + if finalErr != nil { + return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr) + } + return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries) +} - select { - case <-s.done: - case <-done: +// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. +func isResumable(resp *http.Response) bool { + // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. + if resp.StatusCode == http.StatusMethodNotAllowed { + return false } + + return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") } // Close implements the [Connection] interface. func (s *streamableClientConn) Close() error { s.closeOnce.Do(func() { + // Cancel any hanging network requests. + s.cancel() close(s.done) req, err := http.NewRequest(http.MethodDelete, s.url, nil) @@ -803,3 +936,37 @@ func (s *streamableClientConn) Close() error { }) return s.closeErr } + +// establishSSE establishes the persistent SSE listening stream. +// It is used for reconnect attempts using the Last-Event-ID header to +// resume a broken stream where it left off. +func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { + req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil) + if err != nil { + return nil, err + } + s.mu.Lock() + if s._sessionID != "" { + req.Header.Set("Mcp-Session-Id", s._sessionID) + } + s.mu.Unlock() + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + + return s.client.Do(req) +} + +// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. +func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration { + // Calculate the exponential backoff using the grow factor. + backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))) + // Cap the backoffDuration at maxDelay. + backoffDuration = min(backoffDuration, opts.maxDelay) + + // Use a full jitter using backoffDuration + jitter := rand.N(backoffDuration) + + return backoffDuration + jitter +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 87664da9..864265e5 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -10,19 +10,23 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" + "net/http/httputil" "net/url" "strings" "sync" "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/jsonschema" ) func TestStreamableTransports(t *testing.T) { @@ -105,6 +109,127 @@ func TestStreamableTransports(t *testing.T) { } } +// TestClientReplay verifies that the client can recover from a +// mid-stream network failure and receive replayed messages. It uses a proxy +// that is killed and restarted to simulate a recoverable network outage. +func TestClientReplay(t *testing.T) { + notifications := make(chan string) + // 1. Configure the real MCP server. + server := NewServer(testImpl, nil) + + // Use a channel to synchronize the server's message sending with the test's + // proxy-killing action. + serverReadyToKillProxy := make(chan struct{}) + serverClosed := make(chan struct{}) + server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + go func() { + bgCtx := context.Background() + // Send the first two messages immediately. + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + + // Signal the test that it can now kill the proxy. + close(serverReadyToKillProxy) + <-serverClosed + + // These messages should be queued for replay by the server after + // the client's connection drops. + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) + ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + }() + return &CallToolResult{}, nil + }) + realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer realServer.Close() + realServerURL, err := url.Parse(realServer.URL) + if err != nil { + t.Fatalf("Failed to parse real server URL: %v", err) + } + + // 2. Configure a proxy that sits between the client and the real server. + proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL) + proxy := httptest.NewServer(proxyHandler) + proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. + + // 3. Configure the client to connect to the proxy with default options. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ + ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) { + notifications <- params.Message + }}) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) + + // 4. Read and verify messages until the server signals it's ready for the proxy kill. + receivedNotifications := readProgressNotifications(t, ctx, notifications, 2) + + wantReceived := []string{"msg1", "msg2"} + if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) + } + + select { + case <-serverReadyToKillProxy: + // Server has sent the first two messages and is paused. + case <-ctx.Done(): + t.Fatalf("Context timed out before server was ready to kill proxy") + } + + // 5. Simulate a total network failure by closing the proxy. + t.Log("--- Killing proxy to simulate network failure ---") + proxy.CloseClientConnections() + proxy.Close() + close(serverClosed) + + // 6. Simulate network recovery by restarting the proxy on the same address. + t.Logf("--- Restarting proxy on %s ---", proxyAddr) + listener, err := net.Listen("tcp", proxyAddr) + if err != nil { + t.Fatalf("Failed to listen on proxy address: %v", err) + } + restartedProxy := &http.Server{Handler: proxyHandler} + go restartedProxy.Serve(listener) + defer restartedProxy.Close() + + // 7. Continue reading from the same connection object. + // Its internal logic should successfully retry, reconnect to the new proxy, + // and receive the replayed messages. + recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2) + + // 8. Verify the correct messages were received on the recovered connection. + wantRecovered := []string{"msg3", "msg4"} + + if diff := cmp.Diff(wantRecovered, recoveredNotifications); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } +} + +// Helper to read a specific number of progress notifications. +func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { + t.Helper() + var collectedNotifications []string + for { + select { + case n := <-notifications: + collectedNotifications = append(collectedNotifications, n) + if len(collectedNotifications) == count { + return collectedNotifications + } + case <-ctx.Done(): + if len(collectedNotifications) != count { + t.Fatalf("readProgressNotifications(): did not receive expected notifications, got %d, want %d", len(collectedNotifications), count) + } + return collectedNotifications + } + } +} + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP From 4529904c7a2bbe8ebfe497aa92f0044e3c1d52e4 Mon Sep 17 00:00:00 2001 From: ln-12 <36760115+ln-12@users.noreply.github.com> Date: Wed, 23 Jul 2025 13:17:57 +0200 Subject: [PATCH 036/221] Fix minor issues in design.md (#158) design.md: minor fixes --- design/design.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/design/design.md b/design/design.md index 8ab7c152..8804292e 100644 --- a/design/design.md +++ b/design/design.md @@ -471,7 +471,7 @@ server.AddReceivingMiddleware(withLogging) #### Rate Limiting -Rate limiting can be configured using middleware. Please see [examples/rate-limiting](] for an example on how to implement this. +Rate limiting can be configured using middleware. Please see [examples/rate-limiting]() for an example on how to implement this. ### Errors @@ -609,7 +609,7 @@ A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, ```go type CallToolParamsFor[In any] struct { Meta Meta `json:"_meta,omitempty"` - Arguments In `json:"arguments,omitempty"` + Arguments In `json:"arguments,omitempty"` Name string `json:"name"` } From ea6162ca1c3fd49ea66770cbf65185dcba6b1903 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 23 Jul 2025 13:42:50 -0400 Subject: [PATCH 037/221] mcp: incorporate EventStore into StreamableServerTransport (#156) Use the EventStore interface to implement resumption on the server side. --- mcp/event.go | 64 ++++++------------ mcp/event_test.go | 30 +++------ mcp/server.go | 2 + mcp/streamable.go | 165 +++++++++++++++++++++++++++------------------- 4 files changed, 128 insertions(+), 133 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index fbbe1941..9092da76 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -153,10 +153,9 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { // // All of an EventStore's methods must be safe for use by multiple goroutines. type EventStore interface { - // AppendEvent appends data for an outgoing event to given stream, which is part of the - // given session. It returns the index of the event in the stream, suitable for constructing - // an event ID to send to the client. - AppendEvent(_ context.Context, sessionID string, _ StreamID, data []byte) (int, error) + // Append appends data for an outgoing event to given stream, which is part of the + // given session. + Append(_ context.Context, sessionID string, _ StreamID, data []byte) error // After returns an iterator over the data for the given session and stream, beginning // just after the given index. @@ -165,16 +164,15 @@ type EventStore interface { // dropped; it must not return partial results. After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] - // StreamClosed informs the store that the given stream is finished. - // A store cannot rely on this method being called for cleanup. It should institute - // additional mechanisms, such as timeouts, to reclaim storage. - StreamClosed(_ context.Context, sessionID string, streamID StreamID) error - // SessionClosed informs the store that the given session is finished, along // with all of its streams. // A store cannot rely on this method being called for cleanup. It should institute // additional mechanisms, such as timeouts, to reclaim storage. + // SessionClosed(_ context.Context, sessionID string) error + + // There is no StreamClosed method. A server doesn't know when a stream is finished, because + // the client can always send a GET with a Last-Event-ID referring to the stream. } // A dataList is a list of []byte. @@ -210,15 +208,6 @@ func (dl *dataList) removeFirst() int { return r } -// lastIndex returns the index of the last data item in dl. -// It panics if there are none. -func (dl *dataList) lastIndex() int { - if len(dl.data) == 0 { - panic("empty dataList") - } - return dl.first + len(dl.data) - 1 -} - // A MemoryEventStore is an [EventStore] backed by memory. type MemoryEventStore struct { mu sync.Mutex @@ -267,9 +256,8 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { } } -// AppendEvent implements [EventStore.AppendEvent] by recording data -// in memory. -func (s *MemoryEventStore) AppendEvent(_ context.Context, sessionID string, streamID StreamID, data []byte) (int, error) { +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { s.mu.Lock() defer s.mu.Unlock() @@ -288,9 +276,13 @@ func (s *MemoryEventStore) AppendEvent(_ context.Context, sessionID string, stre s.purge() dl.appendData(data) s.nBytes += len(data) - return dl.lastIndex(), nil + return nil } +// ErrEventsPurged is the error that [EventStore.After] should return if the event just after the +// index is no longer available. +var ErrEventsPurged = errors.New("data purged") + // After implements [EventStore.After]. func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] { // Return the data items to yield. @@ -306,10 +298,12 @@ func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID S if !ok { return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) } - if dl.first > index { - return nil, fmt.Errorf("MemoryEventStore.After: data purged at index %d, stream ID %v, session %q", index, streamID, sessionID) + start := index + 1 + if dl.first > start { + return nil, fmt.Errorf("MemoryEventStore.After: index %d, stream ID %v, session %q: %w", + index, streamID, sessionID, ErrEventsPurged) } - return slices.Clone(dl.data[index-dl.first:]), nil + return slices.Clone(dl.data[start-dl.first:]), nil } return func(yield func([]byte, error) bool) { @@ -326,26 +320,6 @@ func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID S } } -// StreamClosed implements [EventStore.StreamClosed]. -func (s *MemoryEventStore) StreamClosed(_ context.Context, sessionID string, streamID StreamID) error { - if sessionID == "" { - panic("empty sessionID") - } - - s.mu.Lock() - defer s.mu.Unlock() - - sm := s.store[sessionID] - dl := sm[streamID] - s.nBytes -= dl.size - delete(sm, streamID) - if len(sm) == 0 { - delete(s.store, sessionID) - } - s.validate() - return nil -} - // SessionClosed implements [EventStore.SessionClosed]. func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { s.mu.Lock() diff --git a/mcp/event_test.go b/mcp/event_test.go index 9b01555f..147a947a 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -105,7 +105,7 @@ func TestMemoryEventStoreState(t *testing.T) { ctx := context.Background() appendEvent := func(s *MemoryEventStore, sess string, str StreamID, data string) { - if _, err := s.AppendEvent(ctx, sess, str, []byte(data)); err != nil { + if err := s.Append(ctx, sess, str, []byte(data)); err != nil { t.Fatal(err) } } @@ -127,18 +127,6 @@ func TestMemoryEventStoreState(t *testing.T) { "S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4", 8, }, - { - "stream close", - func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") - s.StreamClosed(ctx, "S1", 1) - }, - "S1 2 first=0 d2; S2 8 first=0 d4", - 4, - }, { "session close", func(s *MemoryEventStore) { @@ -218,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) { ctx := context.Background() s := NewMemoryEventStore(nil) s.SetMaxBytes(4) - s.AppendEvent(ctx, "S1", 1, []byte("d1")) - s.AppendEvent(ctx, "S1", 1, []byte("d2")) - s.AppendEvent(ctx, "S1", 1, []byte("d3")) - s.AppendEvent(ctx, "S1", 2, []byte("d4")) // will purge d1 + s.Append(ctx, "S1", 1, []byte("d1")) + s.Append(ctx, "S1", 1, []byte("d2")) + s.Append(ctx, "S1", 1, []byte("d3")) + s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1 want := "S1 1 first=1 d2 d3; S1 2 first=0 d4" if got := s.debugString(); got != want { t.Fatalf("got state %q, want %q", got, want) @@ -234,10 +222,10 @@ func TestMemoryEventStoreAfter(t *testing.T) { want []string wantErr string // if non-empty, error should contain this string }{ - {"S1", 1, 0, nil, "purge"}, - {"S1", 1, 1, []string{"d2", "d3"}, ""}, - {"S1", 1, 2, []string{"d3"}, ""}, - {"S1", 2, 0, []string{"d4"}, ""}, + {"S1", 1, 0, []string{"d2", "d3"}, ""}, + {"S1", 1, 1, []string{"d3"}, ""}, + {"S1", 1, 2, nil, ""}, + {"S1", 2, 0, nil, ""}, {"S1", 3, 0, nil, "unknown stream ID"}, {"S2", 0, 0, nil, "unknown session ID"}, } { diff --git a/mcp/server.go b/mcp/server.go index 6b287ad7..e0f691dc 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -85,9 +85,11 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { if opts.PageSize < 0 { panic(fmt.Errorf("invalid page size %d", opts.PageSize)) } + // TODO(jba): don't modify opts, modify Server.opts. if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } + return &Server{ impl: impl, opts: *opts, diff --git a/mcp/streamable.go b/mcp/streamable.go index f7c6ed63..7dba4504 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "errors" "fmt" "io" "math" @@ -132,7 +133,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if session == nil { - s := NewStreamableServerTransport(randText()) + s := NewStreamableServerTransport(randText(), nil) server := h.getServer(req) // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the @@ -150,27 +151,40 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque session.ServeHTTP(w, req) } +type StreamableServerTransportOptions struct { + // Storage for events, to enable stream resumption. + // If nil, a [MemoryEventStore] with the default maximum size will be used. + EventStore EventStore +} + // NewStreamableServerTransport returns a new [StreamableServerTransport] with -// the given session ID. +// the given session ID and options. // The session ID must be globally unique, that is, different from any other // session ID anywhere, past and future. (We recommend using a crypto random number -// generator to produce one, as in [crypto/rand.Text].) +// generator to produce one, as with [crypto/rand.Text].) // // A StreamableServerTransport implements the server-side of the streamable // transport. -// -// TODO(rfindley): consider adding options here, to configure event storage -// policy. -func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { - return &StreamableServerTransport{ - id: sessionID, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - outgoingMessages: make(map[StreamID][]*streamableMsg), - signals: make(map[StreamID]chan struct{}), - requestStreams: make(map[jsonrpc.ID]StreamID), - streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}), +func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransportOptions) *StreamableServerTransport { + if opts == nil { + opts = &StreamableServerTransportOptions{} + } + t := &StreamableServerTransport{ + id: sessionID, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + outgoing: make(map[StreamID][][]byte), + signals: make(map[StreamID]chan struct{}), + requestStreams: make(map[jsonrpc.ID]StreamID), + streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}), } + if opts != nil { + t.opts = *opts + } + if t.opts.EventStore == nil { + t.opts.EventStore = NewMemoryEventStore(nil) + } + return t } func (t *StreamableServerTransport) SessionID() string { @@ -183,10 +197,10 @@ type StreamableServerTransport struct { nextStreamID atomic.Int64 // incrementing next stream ID id string + opts StreamableServerTransportOptions incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex - // Sessions are closed exactly once. isDone bool done chan struct{} @@ -205,17 +219,14 @@ type StreamableServerTransport struct { // // TODO(rfindley): simplify. - // outgoingMessages is the collection of outgoingMessages messages, keyed by the logical + // outgoing is the collection of outgoing messages, keyed by the logical // stream ID where they should be delivered. // // streamID 0 is used for messages that don't correlate with an incoming // request. // - // Lifecycle: outgoingMessages persists for the duration of the session. - // - // TODO(rfindley): garbage collect this data. For now, we save all outgoingMessages - // messages for the lifespan of the transport. - outgoingMessages map[StreamID][]*streamableMsg + // Lifecycle: persists for the duration of the session. + outgoing map[StreamID][][]byte // signals maps a logical stream ID to a 1-buffered channel, owned by an // incoming HTTP request, that signals that there are messages available to @@ -301,16 +312,19 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id, nextIdx := StreamID(0), 0 + id := StreamID(0) + // By default, we haven't seen a last index. Since indices start at 0, we represent + // that by -1. This is incremented just before each event is written, in streamResponse + // around L407. + lastIdx := -1 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool - id, nextIdx, ok = parseEventID(eid) + id, lastIdx, ok = parseEventID(eid) if !ok { http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) return } - nextIdx++ } t.mu.Lock() @@ -323,7 +337,7 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re t.signals[id] = signal t.mu.Unlock() - t.streamResponse(w, req, id, nextIdx, signal) + t.streamResponse(w, req, id, lastIdx, signal) } func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { @@ -376,26 +390,33 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // TODO(rfindley): consider optimizing for a single incoming request, by // responding with application/json when there is only a single message in // the response. - t.streamResponse(w, req, id, 0, signal) + t.streamResponse(w, req, id, -1, signal) } -func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, nextIndex int, signal chan struct{}) { +// lastIndex is the index of the last seen event if resuming, else -1. +func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, lastIndex int, signal chan struct{}) { defer func() { t.mu.Lock() delete(t.signals, id) t.mu.Unlock() }() - // Stream resumption: adjust outgoing index based on what the user says - // they've received. - if nextIndex > 0 { - t.mu.Lock() - // Clamp nextIndex to outgoing messages. - outgoing := t.outgoingMessages[id] - if nextIndex > len(outgoing) { - nextIndex = len(outgoing) + writes := 0 + + // write one event containing data. + write := func(data []byte) bool { + lastIndex++ + e := Event{ + Name: "message", + ID: formatEventID(id, lastIndex), + Data: data, } - t.mu.Unlock() + if _, err := writeEvent(w, e); err != nil { + // Connection closed or broken. + return false + } + writes++ + return true } w.Header().Set(sessionIDHeader, t.id) @@ -403,37 +424,53 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") - writes := 0 + if lastIndex >= 0 { + // Resume. + for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), id, lastIndex) { + if err != nil { + // TODO: reevaluate these status codes. + // Maybe distinguish between storage errors, which are 500s, and missing + // session or stream ID--can these arise from bad input? + status := http.StatusInternalServerError + if errors.Is(err, ErrEventsPurged) { + status = http.StatusInsufficientStorage + } + http.Error(w, err.Error(), status) + return + } + // The iterator yields events beginning just after lastIndex, or it would have + // yielded an error. + if !write(data) { + return + } + } + } + stream: + // Repeatedly collect pending outgoing events and send them. for { - // Send outgoing messages t.mu.Lock() - outgoing := t.outgoingMessages[id][nextIndex:] + outgoing := t.outgoing[id] + t.outgoing[id] = nil t.mu.Unlock() - for _, msg := range outgoing { - if _, err := writeEvent(w, msg.event); err != nil { - // Connection closed or broken. + for _, data := range outgoing { + if err := t.opts.EventStore.Append(req.Context(), t.id, id, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !write(data) { return } - writes++ - nextIndex++ } t.mu.Lock() nOutstanding := len(t.streamRequests[id]) - nOutgoing := len(t.outgoingMessages[id]) t.mu.Unlock() - // If all requests have been handled and replied to, we can terminate this - // connection. However, in the case of a sequencing violation from the server - // (a send on the request context after the request has been handled), we - // loop until we've written all messages. - // - // TODO(rfindley): should we instead refuse to send messages after the last - // response? Decide, write a test, and change the behavior. - if nextIndex < nOutgoing { - continue // more to send - } + // If all requests have been handled and replied to, we should terminate this connection. + // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." + // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + // TODO(jba): why not terminate regardless of http method? if req.Method == http.MethodPost && nOutstanding == 0 { if writes == 0 { // Spec: If the server accepts the input, the server MUST return HTTP @@ -444,8 +481,9 @@ stream: } select { - case <-signal: - case <-t.done: + case <-signal: // there are new outgoing messages + // return to top of loop + case <-t.done: // session is closed if writes == 0 { http.Error(w, "session terminated", http.StatusGone) } @@ -552,15 +590,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa forConn = 0 } - idx := len(t.outgoingMessages[forConn]) - t.outgoingMessages[forConn] = append(t.outgoingMessages[forConn], &streamableMsg{ - idx: idx, - event: Event{ - Name: "message", - ID: formatEventID(forConn, idx), - Data: data, - }, - }) + t.outgoing[forConn] = append(t.outgoing[forConn], data) if replyTo.IsValid() { // Once we've put the reply on the queue, it's no longer outstanding. delete(t.streamRequests[forConn], replyTo) @@ -586,6 +616,7 @@ func (t *StreamableServerTransport) Close() error { if !t.isDone { t.isDone = true close(t.done) + return t.opts.EventStore.SessionClosed(context.TODO(), t.id) } return nil } From bcbb31f07241cfe7deb2c985164c0e7a0cc6fc92 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Wed, 23 Jul 2025 15:23:50 -0400 Subject: [PATCH 038/221] mcp: add resource subscriptions (#138) This CL adds the ability for clients to subscribe and receive updates for resources as described in https://modelcontextprotocol.io/specification/2025-06-18/server/resources#subscriptions Fixes: #23 --- design/design.md | 17 +++++++++-- mcp/client.go | 20 +++++++++++++ mcp/mcp_test.go | 44 ++++++++++++++++++++++++++- mcp/protocol.go | 32 ++++++++++++++++++++ mcp/server.go | 74 +++++++++++++++++++++++++++++++++++++++++++++- mcp/server_test.go | 33 +++++++++++++++++++-- 6 files changed, 214 insertions(+), 6 deletions(-) diff --git a/design/design.md b/design/design.md index 8804292e..bfabeac7 100644 --- a/design/design.md +++ b/design/design.md @@ -748,13 +748,26 @@ Server sessions also support the spec methods `ListResources` and `ListResourceT #### Subscriptions -ClientSessions can manage change notifications on particular resources: +##### Client-Side Usage + +Use the Subscribe and Unsubscribe methods on a ClientSession to start or stop receiving updates for a specific resource. ```go func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error ``` +To process incoming update notifications, you must provide a ResourceUpdatedHandler in your ClientOptions. The SDK calls this function automatically whenever the server sends a notification for a resource you're subscribed to. + +```go +type ClientOptions struct { + ... + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) +} +``` + +##### Server-Side Implementation + The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to. If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers. @@ -772,7 +785,7 @@ type ServerOptions struct { User code should call `ResourceUpdated` when a subscribed resource changes. ```go -func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error +func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotificationParams) error ``` The server routes these notifications to the server sessions that subscribed to the resource. diff --git a/mcp/client.go b/mcp/client.go index b48ad7a1..b386294c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -60,6 +60,7 @@ type ClientOptions struct { ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams) // If non-zero, defines an interval for regular "ping" requests. @@ -293,6 +294,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), + notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)), notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)), } @@ -386,6 +388,20 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) ( return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params)) } +// Subscribe sends a "resources/subscribe" request to the server, asking for +// notifications when the specified resource changes. +func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, cs, methodSubscribe, orZero[Params](params)) + return err +} + +// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling +// a previous subscription. +func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, cs, methodUnsubscribe, orZero[Params](params)) + return err +} + func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) } @@ -398,6 +414,10 @@ func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSessio return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) } +func (c *Client) callResourceUpdatedHandler(ctx context.Context, s *ClientSession, params *ResourceUpdatedNotificationParams) (Result, error) { + return callNotificationHandler(ctx, c.opts.ResourceUpdatedHandler, s, params) +} + func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { h(ctx, cs, params) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 7da2b857..032181a1 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -60,7 +60,7 @@ func TestEndToEnd(t *testing.T) { // Channels to check if notification callbacks happened. notificationChans := map[string]chan int{} - for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} { + for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} { notificationChans[name] = make(chan int, 1) } waitForNotification := func(t *testing.T, name string) { @@ -78,6 +78,14 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { notificationChans["progress_server"] <- 0 }, + SubscribeHandler: func(context.Context, *SubscribeParams) error { + notificationChans["subscribe"] <- 0 + return nil + }, + UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error { + notificationChans["unsubscribe"] <- 0 + return nil + }, } s := NewServer(testImpl, sopts) AddTool(s, &Tool{ @@ -128,6 +136,9 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { notificationChans["progress_client"] <- 0 }, + ResourceUpdatedHandler: func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) { + notificationChans["resource_updated"] <- 0 + }, } c := NewClient(testImpl, opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) @@ -421,6 +432,37 @@ func TestEndToEnd(t *testing.T) { waitForNotification(t, "progress_server") }) + t.Run("resource_subscriptions", func(t *testing.T) { + err := cs.Subscribe(ctx, &SubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "subscribe") + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + waitForNotification(t, "resource_updated") + err = cs.Unsubscribe(ctx, &UnsubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "unsubscribe") + + // Verify the client does not receive the update after unsubscribing. + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + select { + case <-notificationChans["resource_updated"]: + t.Fatalf("resource updated after unsubscription") + case <-time.After(time.Second): + } + }) + // Disconnect. cs.Close() clientWG.Wait() diff --git a/mcp/protocol.go b/mcp/protocol.go index 4f47c961..00dcd14d 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -859,6 +859,38 @@ type ToolListChangedParams struct { func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } +// Sent from the client to request resources/updated notifications from the +// server whenever a particular resource changes. +type SubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to subscribe to. + URI string `json:"uri"` +} + +// Sent from the client to request cancellation of resources/updated +// notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +// A notification from the server to the client, informing it that a resource +// has changed and may need to be read again. This should only be sent if the +// client previously sent a resources/subscribe request. +type ResourceUpdatedNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + // TODO(jba): add CompleteRequest and related types. // TODO(jba): add ElicitRequest and related types. diff --git a/mcp/server.go b/mcp/server.go index e0f691dc..75e66dbc 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -13,6 +13,7 @@ import ( "fmt" "iter" "log" + "maps" "net/url" "path/filepath" "slices" @@ -43,6 +44,7 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler[*ServerSession] receivingMethodHandler_ MethodHandler[*ServerSession] + resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool } // ServerOptions is used to configure behavior of the server. @@ -64,6 +66,10 @@ type ServerOptions struct { // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeParams) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *UnsubscribeParams) error } // NewServer creates a new MCP server. The resulting server has no features: @@ -89,7 +95,12 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } - + if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil { + panic("SubscribeHandler requires UnsubscribeHandler") + } + if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { + panic("UnsubscribeHandler requires SubscribeHandler") + } return &Server{ impl: impl, opts: *opts, @@ -99,6 +110,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + resourceSubscriptions: make(map[string]map[*ServerSession]bool), } } @@ -225,6 +237,9 @@ func (s *Server) capabilities() *serverCapabilities { } if s.resources.len() > 0 || s.resourceTemplates.len() > 0 { caps.Resources = &resourceCapabilities{ListChanged: true} + if s.opts.SubscribeHandler != nil { + caps.Resources.Subscribe = true + } } return caps } @@ -428,6 +443,57 @@ func fileResourceHandler(dir string) ResourceHandler { } } +// ResourceUpdated sends a notification to all clients that have subscribed to the +// resource specified in params. This method is the primary way for a +// server author to signal that a resource has changed. +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + subscribedSessions := s.resourceSubscriptions[params.URI] + sessions := slices.Collect(maps.Keys(subscribedSessions)) + s.mu.Unlock() + notifySessions(sessions, notificationResourceUpdated, params) + return nil +} + +func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *SubscribeParams) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, params); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.resourceSubscriptions[params.URI] == nil { + s.resourceSubscriptions[params.URI] = make(map[*ServerSession]bool) + } + s.resourceSubscriptions[params.URI][ss] = true + + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *UnsubscribeParams) (*emptyResult, error) { + if s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, params); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if subscribedSessions, ok := s.resourceSubscriptions[params.URI]; ok { + delete(subscribedSessions, ss) + if len(subscribedSessions) == 0 { + delete(s.resourceSubscriptions, params.URI) + } + } + + return &emptyResult{}, nil +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -475,6 +541,10 @@ func (s *Server) disconnect(cc *ServerSession) { s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { return cc2 == cc }) + + for _, subscribedSessions := range s.resourceSubscriptions { + delete(subscribedSessions, cc) + } } // Connect connects the MCP server over the given transport and starts handling @@ -616,6 +686,8 @@ var serverMethodInfos = map[string]methodInfo{ methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)), methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)), + methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)), + methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)), notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)), notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)), diff --git a/mcp/server_test.go b/mcp/server_test.go index d4243d7c..bb539772 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -5,6 +5,7 @@ package mcp import ( + "context" "log" "slices" "testing" @@ -232,6 +233,7 @@ func TestServerCapabilities(t *testing.T) { testCases := []struct { name string configureServer func(s *Server) + serverOpts ServerOptions wantCapabilities *serverCapabilities }{ { @@ -275,6 +277,25 @@ func TestServerCapabilities(t *testing.T) { Resources: &resourceCapabilities{ListChanged: true}, }, }, + { + name: "With resource subscriptions", + configureServer: func(s *Server) { + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + }, + serverOpts: ServerOptions{ + SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + return nil + }, + UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + return nil + }, + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, + }, + }, { name: "With tools", configureServer: func(s *Server) { @@ -294,11 +315,19 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) s.AddTool(&Tool{Name: "t"}, nil) }, + serverOpts: ServerOptions{ + SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + return nil + }, + UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + return nil + }, + }, wantCapabilities: &serverCapabilities{ Completions: &completionCapabilities{}, Logging: &loggingCapabilities{}, Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, Tools: &toolCapabilities{ListChanged: true}, }, }, @@ -306,7 +335,7 @@ func TestServerCapabilities(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - server := NewServer(testImpl, nil) + server := NewServer(testImpl, &tc.serverOpts) tc.configureServer(server) gotCapabilities := server.capabilities() if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { From 94b88a26e033024480c1d6fdbcde34aa5265342d Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Wed, 23 Jul 2025 15:40:57 -0400 Subject: [PATCH 039/221] mcp/server: fix nil PingParams in server.Ping (#163) Use OrZero in case PingParams is nil to avoid typed nil. Fixes: #162 --- mcp/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/server.go b/mcp/server.go index 75e66dbc..6433c5b5 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -610,7 +610,7 @@ func (ss *ServerSession) ID() string { // Ping pings the client. func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { - _, err := handleSend[*emptyResult](ctx, ss, methodPing, params) + _, err := handleSend[*emptyResult](ctx, ss, methodPing, orZero[Params](params)) return err } From 8dd9a819bb283849de91c839f6f1734b86a9fc1b Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 23 Jul 2025 16:02:00 -0400 Subject: [PATCH 040/221] mcp: panic if AddTool method is given a nil InputSchema (#157) The spec requires an input schema for every tool. --- mcp/mcp_test.go | 2 +- mcp/server.go | 14 +++++++++----- mcp/server_test.go | 6 ++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 032181a1..3c222eb4 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -659,7 +659,7 @@ func TestCancellation(t *testing.T) { return nil, nil } _, cs := basicConnection(t, func(s *Server) { - s.AddTool(&Tool{Name: "slow"}, slowRequest) + s.AddTool(&Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest) }) defer cs.Close() diff --git a/mcp/server.go b/mcp/server.go index 6433c5b5..960d35f2 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -12,7 +12,6 @@ import ( "encoding/json" "fmt" "iter" - "log" "maps" "net/url" "path/filepath" @@ -132,13 +131,18 @@ func (s *Server) RemovePrompts(names ...string) { } // AddTool adds a [Tool] to the server, or replaces one with the same name. -// The tool's input schema must be non-nil. // The Tool argument must not be modified after this call. +// +// The tool's input schema must be non-nil. For a tool that takes no input, +// or one where any input is valid, set [Tool.InputSchema] to the empty schema, +// &jsonschema.Schema{}. func (s *Server) AddTool(t *Tool, h ToolHandler) { - // TODO(jba): This is a breaking behavior change. Add before v0.2.0? if t.InputSchema == nil { - log.Printf("mcp: tool %q has a nil input schema. This will panic in a future release.", t.Name) - // panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name)) + // This prevents the tool author from forgetting to write a schema where + // one should be provided. If we papered over this by supplying the empty + // schema, then every input would be validated and the problem wouldn't be + // discovered until runtime, when the LLM sent bad data. + panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name)) } if err := addToolErr(s, t, h); err != nil { panic(err) diff --git a/mcp/server_test.go b/mcp/server_test.go index bb539772..9a36c962 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/jsonschema" ) type testItem struct { @@ -230,6 +231,7 @@ func TestServerPaginateVariousPageSizes(t *testing.T) { } func TestServerCapabilities(t *testing.T) { + tool := &Tool{Name: "t", InputSchema: &jsonschema.Schema{}} testCases := []struct { name string configureServer func(s *Server) @@ -299,7 +301,7 @@ func TestServerCapabilities(t *testing.T) { { name: "With tools", configureServer: func(s *Server) { - s.AddTool(&Tool{Name: "t"}, nil) + s.AddTool(tool, nil) }, wantCapabilities: &serverCapabilities{ Completions: &completionCapabilities{}, @@ -313,7 +315,7 @@ func TestServerCapabilities(t *testing.T) { s.AddPrompt(&Prompt{Name: "p"}, nil) s.AddResource(&Resource{URI: "file:///r"}, nil) s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) - s.AddTool(&Tool{Name: "t"}, nil) + s.AddTool(tool, nil) }, serverOpts: ServerOptions{ SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { From c879ac3c61be9320d36254fb29ac19261a9ad559 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 24 Jul 2025 10:35:14 -0400 Subject: [PATCH 041/221] mcp: gracefully handle a nil server in handlers (#164) If the getServer function passed to NewSSEHandler or NewStreamableHTTPHandler returns nil, serve a 400 instead of panicking. Fixes #161. --- examples/sse/main.go | 2 +- mcp/sse.go | 15 ++++++++++----- mcp/streamable.go | 6 ++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/sse/main.go b/examples/sse/main.go index 99b83e65..3289ee15 100644 --- a/examples/sse/main.go +++ b/examples/sse/main.go @@ -53,5 +53,5 @@ func main() { return nil } }) - http.ListenAndServe(*httpAddr, handler) + log.Fatal(http.ListenAndServe(*httpAddr, handler)) } diff --git a/mcp/sse.go b/mcp/sse.go index 4051f45c..92a7ca1a 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -55,13 +55,14 @@ type SSEHandler struct { // Sessions are created when the client issues a GET request to the server, // which must accept text/event-stream responses (server-sent events). // For each such request, a new [SSEServerTransport] is created with a distinct -// messages endpoint, and connected to the server returned by getServer. It is -// up to the user whether getServer returns a distinct [Server] for each new -// request, or reuses an existing server. -// +// messages endpoint, and connected to the server returned by getServer. // The SSEHandler also handles requests to the message endpoints, by // delegating them to the relevant server transport. // +// The getServer function may return a distinct [Server] for each new +// request, or reuse an existing server. If it returns nil, the handler +// will return a 400 Bad Request. +// // TODO(rfindley): add options. func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { return &SSEHandler{ @@ -208,8 +209,12 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { h.mu.Unlock() }() - // TODO(hxjiang): getServer returns nil will panic. server := h.getServer(req) + if server == nil { + // The getServer argument to NewSSEHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } ss, err := server.Connect(req.Context(), transport) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) diff --git a/mcp/streamable.go b/mcp/streamable.go index 7dba4504..e5ffa642 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -50,6 +50,7 @@ type StreamableHTTPOptions struct { // // The getServer function is used to create or look up servers for new // sessions. It is OK for getServer to return the same server multiple times. +// If getServer returns nil, a 400 Bad Request will be served. func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { return &StreamableHTTPHandler{ getServer: getServer, @@ -135,6 +136,11 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque if session == nil { s := NewStreamableServerTransport(randText(), nil) server := h.getServer(req) + if server == nil { + // The getServer argument to NewStreamableHTTPHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. From 1ea7df32677ac8348362401abcb4fb7600c4f44b Mon Sep 17 00:00:00 2001 From: Adam Koszek Date: Sun, 27 Jul 2025 05:23:52 -0700 Subject: [PATCH 042/221] mcp/examples: Make example more usable. (#160) I'm making this example more useful for debugging: - `-port` is supported - `-host` is supported - error handling is there. --- examples/sse/main.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/sse/main.go b/examples/sse/main.go index 3289ee15..59412b15 100644 --- a/examples/sse/main.go +++ b/examples/sse/main.go @@ -7,13 +7,18 @@ package main import ( "context" "flag" + "fmt" "log" "net/http" + "os" "github.com/modelcontextprotocol/go-sdk/mcp" ) -var httpAddr = flag.String("http", "", "use SSE HTTP at this address") +var ( + host = flag.String("host", "localhost", "host to listen on") + port = flag.String("port", "8080", "port to listen on") +) type SayHiParams struct { Name string `json:"name"` @@ -28,11 +33,19 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam } func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "This program runs MCP servers over SSE HTTP.\n\n") + fmt.Fprintf(os.Stderr, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nEndpoints:\n") + fmt.Fprintf(os.Stderr, " /greeter1 - Greeter 1 service\n") + fmt.Fprintf(os.Stderr, " /greeter2 - Greeter 2 service\n") + os.Exit(1) + } flag.Parse() - if httpAddr == nil || *httpAddr == "" { - log.Fatal("http address not set") - } + addr := fmt.Sprintf("%s:%s", *host, *port) server1 := mcp.NewServer(&mcp.Implementation{Name: "greeter1"}, nil) mcp.AddTool(server1, &mcp.Tool{Name: "greet1", Description: "say hi"}, SayHi) @@ -40,7 +53,7 @@ func main() { server2 := mcp.NewServer(&mcp.Implementation{Name: "greeter2"}, nil) mcp.AddTool(server2, &mcp.Tool{Name: "greet2", Description: "say hello"}, SayHi) - log.Printf("MCP servers serving at %s", *httpAddr) + log.Printf("MCP servers serving at %s", addr) handler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { url := request.URL.Path log.Printf("Handling request for URL %s\n", url) @@ -53,5 +66,5 @@ func main() { return nil } }) - log.Fatal(http.ListenAndServe(*httpAddr, handler)) + log.Fatal(http.ListenAndServe(addr, handler)) } From e8c6e0384177795db72b5fd6bc3379d391f63596 Mon Sep 17 00:00:00 2001 From: cryo Date: Mon, 28 Jul 2025 19:51:49 +0800 Subject: [PATCH 043/221] fix comment (#180) update comment for func `clientMethod` --- mcp/shared.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/shared.go b/mcp/shared.go index fef20946..49dd2b10 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -203,7 +203,7 @@ func serverMethod[P Params, R Result]( } } -// clientMethod is glue for creating a typedMethodHandler from a method on Server. +// clientMethod is glue for creating a typedMethodHandler from a method on Client. func clientMethod[P Params, R Result]( f func(*Client, context.Context, *ClientSession, P) (R, error), ) typedMethodHandler[*ClientSession, P, R] { From 86f71fcb70877379279818fad2f40cc314e34068 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 28 Jul 2025 10:27:02 -0400 Subject: [PATCH 044/221] mcp: add stream type (#171) Consolidate several maps into a single struct. Simplifies the code, for the most part. --- mcp/streamable.go | 152 ++++++++++++++++++++++++++-------------------- 1 file changed, 86 insertions(+), 66 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index e5ffa642..e15c6a69 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -179,11 +179,10 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp id: sessionID, incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), - outgoing: make(map[StreamID][][]byte), - signals: make(map[StreamID]chan struct{}), + streams: make(map[StreamID]*stream), requestStreams: make(map[jsonrpc.ID]StreamID), - streamRequests: make(map[StreamID]map[jsonrpc.ID]struct{}), } + t.streams[0] = newStream(0) if opts != nil { t.opts = *opts } @@ -219,59 +218,66 @@ type StreamableServerTransport struct { // perform the accounting described below when incoming HTTP requests are // handled. // - // The accounting is complicated. It is tempting to merge some of the maps - // below, but they each have different lifecycles, as indicated by Lifecycle: - // comments. - // // TODO(rfindley): simplify. - // outgoing is the collection of outgoing messages, keyed by the logical - // stream ID where they should be delivered. + // streams holds the logical streams for this session, keyed by their ID. + streams map[StreamID]*stream + + // requestStreams maps incoming requests to their logical stream ID. // - // streamID 0 is used for messages that don't correlate with an incoming - // request. + // Lifecycle: requestStreams persists for the duration of the session. // - // Lifecycle: persists for the duration of the session. - outgoing map[StreamID][][]byte + // TODO(rfindley): clean up once requests are handled. + requestStreams map[jsonrpc.ID]StreamID +} + +// A stream is a single logical stream of SSE events within a server session. +// A stream begins with a client request, or with a client GET that has +// no Last-Event-ID header. +// A stream ends only when its session ends; we cannot determine its end otherwise, +// since a client may send a GET with a Last-Event-ID that references the stream +// at any time. +type stream struct { + // id is the logical ID for the stream, unique within a session. + // ID 0 is used for messages that don't correlate with an incoming request. + id StreamID - // signals maps a logical stream ID to a 1-buffered channel, owned by an + // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport. + + // outgoing is the list of outgoing messages, enqueued by server methods that + // write notifications and responses, and dequeued by streamResponse. + outgoing [][]byte + + // signal is a 1-buffered channel, owned by an // incoming HTTP request, that signals that there are messages available to - // write into the HTTP response. Signals guarantees that at most one HTTP + // write into the HTTP response. This guarantees that at most one HTTP // response can receive messages for a logical stream. After claiming // the stream, incoming requests should read from outgoing, to ensure // that no new messages are missed. // - // Lifecycle: signals persists for the duration of an HTTP POST or GET + // Lifecycle: persists for the duration of an HTTP POST or GET // request for the given streamID. - signals map[StreamID]chan struct{} + signal chan struct{} - // requestStreams maps incoming requests to their logical stream ID. - // - // Lifecycle: requestStreams persists for the duration of the session. + // streamRequests is the set of unanswered incoming RPCs for the stream. // - // TODO(rfindley): clean up once requests are handled. - requestStreams map[jsonrpc.ID]StreamID - - // streamRequests tracks the set of unanswered incoming RPCs for each logical - // stream. - // - // When the server has responded to each request, the stream should be - // closed. - // - // Lifecycle: streamRequests values persist as until the requests have been + // Lifecycle: requests values persist until the requests have been // replied to by the server. Notably, NOT until they are sent to an HTTP // response, as delivery is not guaranteed. - streamRequests map[StreamID]map[jsonrpc.ID]struct{} + requests map[jsonrpc.ID]struct{} } -type StreamID int64 - -// a streamableMsg is an SSE event with an index into its logical stream. -type streamableMsg struct { - idx int - event Event +func newStream(id StreamID) *stream { + return &stream{ + id: id, + requests: make(map[jsonrpc.ID]struct{}), + } } +// A StreamID identifies a stream of SSE events. It is unique within the stream's +// [ServerSession]. +type StreamID int64 + // Connect implements the [Transport] interface. // // TODO(rfindley): Connect should return a new object. @@ -334,16 +340,21 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re } t.mu.Lock() - if _, ok := t.signals[id]; ok { + stream, ok := t.streams[id] + if !ok { + http.Error(w, "unknown stream", http.StatusBadRequest) + t.mu.Unlock() + return + } + if stream.signal != nil { http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest) t.mu.Unlock() return } - signal := make(chan struct{}, 1) - t.signals[id] = signal + stream.signal = make(chan struct{}, 1) t.mu.Unlock() - t.streamResponse(w, req, id, lastIdx, signal) + t.streamResponse(stream, w, req, lastIdx) } func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { @@ -375,17 +386,17 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R } // Update accounting for this request. - id := StreamID(t.nextStreamID.Add(1)) - signal := make(chan struct{}, 1) + stream := newStream(StreamID(t.nextStreamID.Add(1))) t.mu.Lock() + t.streams[stream.id] = stream if len(requests) > 0 { - t.streamRequests[id] = make(map[jsonrpc.ID]struct{}) + stream.requests = make(map[jsonrpc.ID]struct{}) } for reqID := range requests { - t.requestStreams[reqID] = id - t.streamRequests[id][reqID] = struct{}{} + t.requestStreams[reqID] = stream.id + stream.requests[reqID] = struct{}{} } - t.signals[id] = signal + stream.signal = make(chan struct{}, 1) t.mu.Unlock() // Publish incoming messages. @@ -396,17 +407,24 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // TODO(rfindley): consider optimizing for a single incoming request, by // responding with application/json when there is only a single message in // the response. - t.streamResponse(w, req, id, -1, signal) + t.streamResponse(stream, w, req, -1) } // lastIndex is the index of the last seen event if resuming, else -1. -func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id StreamID, lastIndex int, signal chan struct{}) { +func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) { defer func() { t.mu.Lock() - delete(t.signals, id) + stream.signal = nil t.mu.Unlock() }() + t.mu.Lock() + // Although there is a gap in locking between when stream.signal is set and here, + // it cannot change, because it is changed only when non-nil, and it is only + // set to nil in the defer above. + signal := stream.signal + t.mu.Unlock() + writes := 0 // write one event containing data. @@ -414,11 +432,12 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h lastIndex++ e := Event{ Name: "message", - ID: formatEventID(id, lastIndex), + ID: formatEventID(stream.id, lastIndex), Data: data, } if _, err := writeEvent(w, e); err != nil { // Connection closed or broken. + // TODO: log when we add server-side logging. return false } writes++ @@ -432,7 +451,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h if lastIndex >= 0 { // Resume. - for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), id, lastIndex) { + for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), stream.id, lastIndex) { if err != nil { // TODO: reevaluate these status codes. // Maybe distinguish between storage errors, which are 500s, and missing @@ -456,12 +475,12 @@ stream: // Repeatedly collect pending outgoing events and send them. for { t.mu.Lock() - outgoing := t.outgoing[id] - t.outgoing[id] = nil + outgoing := stream.outgoing + stream.outgoing = nil t.mu.Unlock() for _, data := range outgoing { - if err := t.opts.EventStore.Append(req.Context(), t.id, id, data); err != nil { + if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -471,7 +490,7 @@ stream: } t.mu.Lock() - nOutstanding := len(t.streamRequests[id]) + nOutstanding := len(stream.requests) t.mu.Unlock() // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." @@ -585,30 +604,31 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa t.mu.Lock() defer t.mu.Unlock() if t.isDone { - return fmt.Errorf("session is closed") // TODO: should this be EOF? + return errors.New("session is closed") } - if _, ok := t.streamRequests[forConn]; !ok && forConn != 0 { + stream := t.streams[forConn] + if stream == nil { + return fmt.Errorf("no stream with ID %d", forConn) + } + if len(stream.requests) == 0 && forConn != 0 { // No outstanding requests for this connection, which means it is logically // done. This is a sequencing violation from the server, so we should report // a side-channel error here. Put the message on the general queue to avoid // dropping messages. - forConn = 0 + stream = t.streams[0] } - t.outgoing[forConn] = append(t.outgoing[forConn], data) + stream.outgoing = append(stream.outgoing, data) if replyTo.IsValid() { // Once we've put the reply on the queue, it's no longer outstanding. - delete(t.streamRequests[forConn], replyTo) - if len(t.streamRequests[forConn]) == 0 { - delete(t.streamRequests, forConn) - } + delete(stream.requests, replyTo) } // Signal work. - if c, ok := t.signals[forConn]; ok { + if stream.signal != nil { select { - case c <- struct{}{}: + case stream.signal <- struct{}{}: default: } } From 6e5ba9585f7e607462ccd5b990d66c4612f85849 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 28 Jul 2025 10:28:01 -0400 Subject: [PATCH 045/221] mcp: remove create-repo.sh (#178) It was a one-off script whose purpose has been served. --- mcp/create-repo.sh | 80 ---------------------------------------------- 1 file changed, 80 deletions(-) delete mode 100755 mcp/create-repo.sh diff --git a/mcp/create-repo.sh b/mcp/create-repo.sh deleted file mode 100755 index a68964bb..00000000 --- a/mcp/create-repo.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash - -# This script creates an MCP SDK repo from x/tools/internal/mcp (and friends). -# It will be used as a one-off to create github.com/modelcontextprotocol/go-sdk. -# -# Requires https://github.com/newren/git-filter-repo. - -set -eu - -# Check if exactly one argument is provided -if [ "$#" -ne 1 ]; then - echo "create-repo.sh: create a standalone mcp SDK repo from x/tools" - echo "Usage: $0 " - exit 1 -fi >&2 - -src=$(go list -m -f {{.Dir}} golang.org/x/tools) -dest="$1" - -echo "Filtering MCP commits from ${src} to ${dest}..." >&2 - -startdir=$(pwd) -tempdir=$(mktemp -d) -function cleanup { - echo "cleaning up ${tempdir}" - rm -rf "$tempdir" -} >&2 -trap cleanup EXIT SIGINT - -echo "Checking out to ${tempdir}" - -git clone --bare "${src}" "${tempdir}" -git -C "${tempdir}" --git-dir=. filter-repo \ - --path internal/mcp/jsonschema --path-rename internal/mcp/jsonschema:jsonschema \ - --path internal/mcp/design --path-rename internal/mcp/design:design \ - --path internal/mcp/examples --path-rename internal/mcp/examples:examples \ - --path internal/mcp/internal --path-rename internal/mcp/internal:internal \ - --path internal/mcp/README.md --path-rename internal/mcp/README.md:README.md \ - --path internal/mcp/CONTRIBUTING.md --path-rename internal/mcp/CONTRIBUTING.md:CONTRIBUTING.md \ - --path internal/mcp --path-rename internal/mcp:mcp \ - --path internal/jsonrpc2_v2 --path-rename internal/jsonrpc2_v2:internal/jsonrpc2 \ - --path internal/xcontext \ - --replace-text "${startdir}/mcp-repo-replace.txt" \ - --force -mkdir ${dest} -cd "${dest}" -git init - -cat << EOF > LICENSE -MIT License - -Copyright (c) 2025 Go MCP SDK Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -EOF - -git add LICENSE && git commit -m "Initial commit: add LICENSE" -git remote add filtered_source "${tempdir}" -git pull filtered_source master --allow-unrelated-histories -git remote remove filtered_source -go mod init github.com/modelcontextprotocol/go-sdk && go get go@1.23.0 -go mod tidy -git add go.mod go.sum -git commit -m "all: add go.mod and go.sum file" From 261a63ccc01e53f757a79936fbd9d6df463013fd Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 28 Jul 2025 10:29:59 -0400 Subject: [PATCH 046/221] mcp: don't panic on bad JSON RPC input (#177) Log internal errors from the jsonprc2 package, instead of panicking. Fixes #175. --- mcp/transport.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mcp/transport.go b/mcp/transport.go index a7de5061..0441c617 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log" "net" "os" "sync" @@ -120,6 +121,7 @@ func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error OnDone: func() { b.disconnect(h) }, + OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, }) assert(preempter.conn != nil, "unbound preempter") h.setConn(conn) From 619bc41581e699bf64ad13ec334fadd2fbc327a0 Mon Sep 17 00:00:00 2001 From: Koichi Shiraishi Date: Tue, 29 Jul 2025 01:18:03 +0900 Subject: [PATCH 047/221] internal/jsonrpc2,mcp: fix typos (#183) - add links to doc strings - use correct symbol and package names --- internal/jsonrpc2/conn.go | 4 ++-- mcp/server.go | 7 +++---- mcp/shared.go | 4 ++-- mcp/transport.go | 6 +++--- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 32239454..fbe0688b 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -125,7 +125,7 @@ func (c *Connection) updateInFlight(f func(*inFlightState)) { // that and avoided making any updates that would cause the state to be // non-idle.) if !s.idle() { - panic("jsonrpc2_v2: updateInFlight transitioned to non-idle when already done") + panic("jsonrpc2: updateInFlight transitioned to non-idle when already done") } return default: @@ -718,7 +718,7 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e req.cancel() c.updateInFlight(func(s *inFlightState) { if s.incoming == 0 { - panic("jsonrpc2_v2: processResult called when incoming count is already zero") + panic("jsonrpc2: processResult called when incoming count is already zero") } s.incoming-- }) diff --git a/mcp/server.go b/mcp/server.go index 960d35f2..4270578d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -29,7 +29,7 @@ const DefaultPageSize = 1000 // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP -// sessions by using [Server.Start] or [Server.Run]. +// sessions by using [Server.Run]. type Server struct { // fixed at creation impl *Implementation @@ -74,8 +74,7 @@ type ServerOptions struct { // NewServer creates a new MCP server. The resulting server has no features: // add features using the various Server.AddXXX methods, and the [AddTool] function. // -// The server can be connected to one or more MCP clients using [Server.Start] -// or [Server.Run]. +// The server can be connected to one or more MCP clients using [Server.Run]. // // The first argument must not be nil. // @@ -733,7 +732,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, } // For the streamable transport, we need the request ID to correlate // server->client calls and notifications to the incoming request from which - // they originated. See [idContext] for details. + // they originated. See [idContextKey] for details. ctx = context.WithValue(ctx, idContextKey{}, req.ID) return handleReceive(ctx, ss, req) } diff --git a/mcp/shared.go b/mcp/shared.go index 49dd2b10..cb5234a4 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -44,7 +44,7 @@ type MethodHandler[S Session] func(ctx context.Context, _ S, method string, para // the compiler would complain. type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerSession] -// A Session is either a ClientSession or a ServerSession. +// A Session is either a [ClientSession] or a [ServerSession]. type Session interface { *ClientSession | *ServerSession // ID returns the session ID, or the empty string if there is none. @@ -57,7 +57,7 @@ type Session interface { getConn() *jsonrpc2.Connection } -// Middleware is a function from MethodHandlers to MethodHandlers. +// Middleware is a function from [MethodHandler] to [MethodHandler]. type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] // addMiddleware wraps the handler in the middleware functions. diff --git a/mcp/transport.go b/mcp/transport.go index 0441c617..5175f6f0 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -134,7 +134,7 @@ type canceller struct { conn *jsonrpc2.Connection } -// Preempt implements jsonrpc2.Preempter. +// Preempt implements [jsonrpc2.Preempter]. func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { if req.Method == "notifications/cancelled" { var params CancelledParams @@ -203,7 +203,7 @@ type loggingConn struct { func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } -// loggingReader is a stream middleware that logs incoming messages. +// Read is a stream middleware that logs incoming messages. func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { msg, err := s.delegate.Read(ctx) if err != nil { @@ -218,7 +218,7 @@ func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { return msg, err } -// loggingWriter is a stream middleware that logs outgoing messages. +// Write is a stream middleware that logs outgoing messages. func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { err := s.delegate.Write(ctx, msg) if err != nil { From a5aa370ea07c3dc28d353a53f8fc49ed6ae1fd92 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 28 Jul 2025 13:58:37 -0400 Subject: [PATCH 048/221] mcp: various cleanups to streamable code (#184) - move done field outside of mutex hat - refactor handler into a function that returns `error` to avoid repeated `http.Error` calls - use an atomic for the signal channel to simplify locking --- mcp/streamable.go | 129 ++++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 68 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index e15c6a69..7b814012 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -204,11 +204,11 @@ type StreamableServerTransport struct { id string opts StreamableServerTransportOptions incoming chan jsonrpc.Message // messages from the client to the server + done chan struct{} mu sync.Mutex // Sessions are closed exactly once. isDone bool - done chan struct{} // Sessions can have multiple logical connections, corresponding to HTTP // requests. Additionally, logical sessions may be resumed by subsequent HTTP @@ -242,23 +242,26 @@ type stream struct { // ID 0 is used for messages that don't correlate with an incoming request. id StreamID - // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport. + // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals + // that there are messages available to write into the HTTP response. + // In addition, the presence of a channel guarantees that at most one HTTP response + // can receive messages for a logical stream. After claiming the stream, incoming + // requests should read from outgoing, to ensure that no new messages are missed. + // + // To simplify locking, signal is an atomic. We need an atomic.Pointer, because + // you can't set an atomic.Value to nil. + // + // Lifecycle: each channel value persists for the duration of an HTTP POST or + // GET request for the given streamID. + signal atomic.Pointer[chan struct{}] + + // The following mutable fields are protected by the mutex of the containing + // StreamableServerTransport. // outgoing is the list of outgoing messages, enqueued by server methods that // write notifications and responses, and dequeued by streamResponse. outgoing [][]byte - // signal is a 1-buffered channel, owned by an - // incoming HTTP request, that signals that there are messages available to - // write into the HTTP response. This guarantees that at most one HTTP - // response can receive messages for a logical stream. After claiming - // the stream, incoming requests should read from outgoing, to ensure - // that no new messages are missed. - // - // Lifecycle: persists for the duration of an HTTP POST or GET - // request for the given streamID. - signal chan struct{} - // streamRequests is the set of unanswered incoming RPCs for the stream. // // Lifecycle: requests values persist until the requests have been @@ -274,6 +277,11 @@ func newStream(id StreamID) *stream { } } +func signalChanPtr() *chan struct{} { + c := make(chan struct{}, 1) + return &c +} + // A StreamID identifies a stream of SSE events. It is unique within the stream's // [ServerSession]. type StreamID int64 @@ -310,19 +318,25 @@ type idContextKey struct{} // ServeHTTP handles a single HTTP request for the session. func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + status := 0 + message := "" switch req.Method { case http.MethodGet: - t.serveGET(w, req) + status, message = t.serveGET(w, req) case http.MethodPost: - t.servePOST(w, req) + status, message = t.servePOST(w, req) default: // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. w.Header().Set("Allow", "GET, POST") - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + status = http.StatusMethodNotAllowed + message = "unsupported method" + } + if status != 0 && status != http.StatusOK { + http.Error(w, message, status) } } -func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { +func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) { // connID 0 corresponds to the default GET request. id := StreamID(0) // By default, we haven't seen a last index. Since indices start at 0, we represent @@ -334,49 +348,39 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re var ok bool id, lastIdx, ok = parseEventID(eid) if !ok { - http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) - return + return http.StatusBadRequest, fmt.Sprintf("malformed Last-Event-ID %q", eid) } } t.mu.Lock() stream, ok := t.streams[id] + t.mu.Unlock() if !ok { - http.Error(w, "unknown stream", http.StatusBadRequest) - t.mu.Unlock() - return + return http.StatusBadRequest, "unknown stream" } - if stream.signal != nil { - http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest) - t.mu.Unlock() - return + if !stream.signal.CompareAndSwap(nil, signalChanPtr()) { + // The CAS returned false, meaning that the comparison failed: stream.signal is not nil. + return http.StatusBadRequest, "stream ID conflicts with ongoing stream" } - stream.signal = make(chan struct{}, 1) - t.mu.Unlock() - - t.streamResponse(stream, w, req, lastIdx) + return t.streamResponse(stream, w, req, lastIdx) } -func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { +func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) (int, string) { if len(req.Header.Values("Last-Event-ID")) > 0 { - http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) - return + return http.StatusBadRequest, "can't send Last-Event-ID for POST request" } // Read incoming messages. body, err := io.ReadAll(req.Body) if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) - return + return http.StatusBadRequest, "failed to read body" } if len(body) == 0 { - http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) - return + return http.StatusBadRequest, "POST requires a non-empty body" } incoming, _, err := readBatch(body) if err != nil { - http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) - return + return http.StatusBadRequest, fmt.Sprintf("malformed payload: %v", err) } requests := make(map[jsonrpc.ID]struct{}) for _, msg := range incoming { @@ -396,8 +400,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R t.requestStreams[reqID] = stream.id stream.requests[reqID] = struct{}{} } - stream.signal = make(chan struct{}, 1) t.mu.Unlock() + stream.signal.Store(signalChanPtr()) // Publish incoming messages. for _, msg := range incoming { @@ -407,23 +411,12 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // TODO(rfindley): consider optimizing for a single incoming request, by // responding with application/json when there is only a single message in // the response. - t.streamResponse(stream, w, req, -1) + return t.streamResponse(stream, w, req, -1) } // lastIndex is the index of the last seen event if resuming, else -1. -func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) { - defer func() { - t.mu.Lock() - stream.signal = nil - t.mu.Unlock() - }() - - t.mu.Lock() - // Although there is a gap in locking between when stream.signal is set and here, - // it cannot change, because it is changed only when non-nil, and it is only - // set to nil in the defer above. - signal := stream.signal - t.mu.Unlock() +func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) { + defer stream.signal.Store(nil) writes := 0 @@ -437,7 +430,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon } if _, err := writeEvent(w, e); err != nil { // Connection closed or broken. - // TODO: log when we add server-side logging. + // TODO(#170): log when we add server-side logging. return false } writes++ @@ -460,13 +453,12 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon if errors.Is(err, ErrEventsPurged) { status = http.StatusInsufficientStorage } - http.Error(w, err.Error(), status) - return + return status, err.Error() } // The iterator yields events beginning just after lastIndex, or it would have // yielded an error. if !write(data) { - return + return 0, "" } } } @@ -481,11 +473,10 @@ stream: for _, data := range outgoing { if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + return http.StatusInternalServerError, err.Error() } if !write(data) { - return + return 0, "" } } @@ -495,22 +486,22 @@ stream: // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - // TODO(jba): why not terminate regardless of http method? + // TODO(jba,findleyr): why not terminate regardless of http method? if req.Method == http.MethodPost && nOutstanding == 0 { if writes == 0 { // Spec: If the server accepts the input, the server MUST return HTTP // status code 202 Accepted with no body. w.WriteHeader(http.StatusAccepted) } - return + return 0, "" } select { - case <-signal: // there are new outgoing messages + case <-*stream.signal.Load(): // there are new outgoing messages // return to top of loop case <-t.done: // session is closed if writes == 0 { - http.Error(w, "session terminated", http.StatusGone) + return http.StatusGone, "session terminated" } break stream case <-req.Context().Done(): @@ -520,6 +511,7 @@ stream: break stream } } + return 0, "" } // Event IDs: encode both the logical connection ID and the index, as @@ -625,10 +617,11 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa delete(stream.requests, replyTo) } - // Signal work. - if stream.signal != nil { + // Signal streamResponse that new work is available. + signalp := stream.signal.Load() + if signalp != nil { select { - case stream.signal <- struct{}{}: + case *signalp <- struct{}{}: default: } } From 2b6f7b51854e813c06ddec75a1abb5da33345574 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Tue, 29 Jul 2025 05:20:08 -0400 Subject: [PATCH 049/221] mcp/server: add initial capability server option (#182) This CL adds an option for servers to advertise support for features that have not been added yet. This includes Prompts, Tools, and Resources as these could be dynamically added after initialization. Fixes: #135. --- mcp/server.go | 16 +++++++++++++--- mcp/server_test.go | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 4270578d..e88455a2 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -69,6 +69,15 @@ type ServerOptions struct { SubscribeHandler func(context.Context, *SubscribeParams) error // Function called when a client session unsubscribes from a resource. UnsubscribeHandler func(context.Context, *UnsubscribeParams) error + // If true, advertises the prompts capability during initialization, + // even if no prompts have been registered. + HasPrompts bool + // If true, advertises the resources capability during initialization, + // even if no resources have been registered. + HasResources bool + // If true, advertises the tools capability during initialization, + // even if no tools have been registered. + HasTools bool } // NewServer creates a new MCP server. The resulting server has no features: @@ -229,16 +238,17 @@ func (s *Server) capabilities() *serverCapabilities { defer s.mu.Unlock() caps := &serverCapabilities{ + // TODO(samthanawalla): check for completionHandler before advertising capability. Completions: &completionCapabilities{}, Logging: &loggingCapabilities{}, } - if s.tools.len() > 0 { + if s.opts.HasTools || s.tools.len() > 0 { caps.Tools = &toolCapabilities{ListChanged: true} } - if s.prompts.len() > 0 { + if s.opts.HasPrompts || s.prompts.len() > 0 { caps.Prompts = &promptCapabilities{ListChanged: true} } - if s.resources.len() > 0 || s.resourceTemplates.len() > 0 { + if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 { caps.Resources = &resourceCapabilities{ListChanged: true} if s.opts.SubscribeHandler != nil { caps.Resources.Subscribe = true diff --git a/mcp/server_test.go b/mcp/server_test.go index 9a36c962..79c607e9 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -333,6 +333,22 @@ func TestServerCapabilities(t *testing.T) { Tools: &toolCapabilities{ListChanged: true}, }, }, + { + name: "With initial capabilities", + configureServer: func(s *Server) {}, + serverOpts: ServerOptions{ + HasPrompts: true, + HasResources: true, + HasTools: true, + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true}, + Tools: &toolCapabilities{ListChanged: true}, + }, + }, } for _, tc := range testCases { From 64b5b9125e21689438f17bc58bd29523519d7ef4 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 29 Jul 2025 10:28:49 -0400 Subject: [PATCH 050/221] mcp: support application/json in streamable client (#181) The client handles POST responses with content type application/json. For: #10 Fixes #129 --- mcp/streamable.go | 14 ++++++++++-- mcp/streamable_test.go | 48 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 7b814012..80304a04 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -813,6 +813,8 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } +// postMessage POSTs msg to the server and reads the response. +// It returns the session ID from the response. func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { @@ -849,9 +851,17 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string // Section 2.1: The SSE stream is initiated after a POST. go s.handleSSE(resp) case "application/json": - // TODO: read the body and send to s.incoming (in a select that also receives from s.done). + body, err := io.ReadAll(resp.Body) resp.Body.Close() - return "", fmt.Errorf("streamable HTTP client does not yet support raw JSON responses") + if err != nil { + return "", err + } + select { + case s.incoming <- body: + case <-s.done: + // The connection was closed by the client; exit gracefully. + } + return sessionID, nil default: resp.Body.Close() return "", fmt.Errorf("unsupported content type %q", ct) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 864265e5..185bc638 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -158,7 +158,8 @@ func TestClientReplay(t *testing.T) { client := NewClient(testImpl, &ClientOptions{ ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) { notifications <- params.Message - }}) + }, + }) clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) if err != nil { t.Fatalf("client.Connect() failed: %v", err) @@ -678,6 +679,51 @@ func mustMarshal(t *testing.T, v any) json.RawMessage { return data } +func TestStreamableClientTransportApplicationJSON(t *testing.T) { + // Test handling of application/json responses. + ctx := context.Background() + resp := func(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ + ID: jsonrpc2.Int64ID(id), + Result: mustMarshal(t, result), + Error: err, + } + } + initResult := &InitializeResult{ + Capabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + } + initResp := resp(1, initResult, nil) + + serverHandler := func(w http.ResponseWriter, r *http.Request) { + data, err := jsonrpc2.EncodeMessage(initResp) + if err != nil { + t.Fatal(err) + } + w.Header().Set("Content-Type", "application/json") + w.Write(data) + } + + httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, nil) + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + if diff := cmp.Diff(initResult, session.initializeResult); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } +} + func TestEventID(t *testing.T) { tests := []struct { sid StreamID From 3ac4ca9c5b80b2d155168b5f37259eb21fcd26f4 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 29 Jul 2025 10:49:16 -0400 Subject: [PATCH 051/221] mcp: streamable.go: clarifications (#189) Clarify or remove TODOs. Make some code clearer. --- mcp/streamable.go | 58 +++++++++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 80304a04..0921c22d 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -42,8 +42,8 @@ type StreamableHTTPHandler struct { // StreamableHTTPOptions is a placeholder options struct for future // configuration of the StreamableHTTP handler. type StreamableHTTPOptions struct { - // TODO(rfindley): support configurable session ID generation and event - // store, session retention, and event retention. + // TODO: support configurable session ID generation (?) + // TODO: support session retention (?) } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -61,7 +61,7 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea // closeAll closes all ongoing sessions. // // TODO(rfindley): investigate the best API for callers to configure their -// session lifecycle. +// session lifecycle. (?) // // Should we allow passing in a session store? That would allow the handler to // be stateless. @@ -118,7 +118,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } h.sessionsMu.Lock() - delete(h.sessions, session.id) + delete(h.sessions, session.sessionID) h.sessionsMu.Unlock() session.Close() w.WriteHeader(http.StatusNoContent) @@ -149,7 +149,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } h.sessionsMu.Lock() - h.sessions[s.id] = s + h.sessions[s.sessionID] = s h.sessionsMu.Unlock() session = s } @@ -176,7 +176,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp opts = &StreamableServerTransportOptions{} } t := &StreamableServerTransport{ - id: sessionID, + sessionID: sessionID, incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), streams: make(map[StreamID]*stream), @@ -193,7 +193,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp } func (t *StreamableServerTransport) SessionID() string { - return t.id + return t.sessionID } // A StreamableServerTransport implements the [Transport] interface for a @@ -201,10 +201,10 @@ func (t *StreamableServerTransport) SessionID() string { type StreamableServerTransport struct { nextStreamID atomic.Int64 // incrementing next stream ID - id string - opts StreamableServerTransportOptions - incoming chan jsonrpc.Message // messages from the client to the server - done chan struct{} + sessionID string + opts StreamableServerTransportOptions + incoming chan jsonrpc.Message // messages from the client to the server + done chan struct{} mu sync.Mutex // Sessions are closed exactly once. @@ -217,17 +217,20 @@ type StreamableServerTransport struct { // Therefore, we use a logical connection ID to key the connection state, and // perform the accounting described below when incoming HTTP requests are // handled. - // - // TODO(rfindley): simplify. // streams holds the logical streams for this session, keyed by their ID. + // TODO: streams are never deleted, so the memory for a connection grows without + // bound. If we deleted a stream when the response is sent, we would lose the ability + // to replay if there was a cut just before the response was transmitted. + // Perhaps we could have a TTL for streams that starts just after the response. streams map[StreamID]*stream // requestStreams maps incoming requests to their logical stream ID. // // Lifecycle: requestStreams persists for the duration of the session. // - // TODO(rfindley): clean up once requests are handled. + // TODO(rfindley): clean up once requests are handled. See the TODO for streams + // above. requestStreams map[jsonrpc.ID]StreamID } @@ -288,7 +291,7 @@ type StreamID int64 // Connect implements the [Transport] interface. // -// TODO(rfindley): Connect should return a new object. +// TODO(rfindley): Connect should return a new object. (Why?) func (s *StreamableServerTransport) Connect(context.Context) (Connection, error) { return s, nil } @@ -411,6 +414,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // TODO(rfindley): consider optimizing for a single incoming request, by // responding with application/json when there is only a single message in // the response. + // (But how would we know there is only a single message? For example, couldn't + // a progress notification be sent before a response on the same context?) return t.streamResponse(stream, w, req, -1) } @@ -437,7 +442,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon return true } - w.Header().Set(sessionIDHeader, t.id) + w.Header().Set(sessionIDHeader, t.sessionID) w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") @@ -486,7 +491,9 @@ stream: // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - // TODO(jba,findleyr): why not terminate regardless of http method? + // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET + // (stream ID 0) will never have requests, and should remain open indefinitely. + // TODO: implement the GET case. if req.Method == http.MethodPost && nOutstanding == 0 { if writes == 0 { // Spec: If the server accepts the input, the server MUST return HTTP @@ -563,11 +570,12 @@ func (t *StreamableServerTransport) Read(ctx context.Context) (jsonrpc.Message, // Write implements the [Connection] interface. func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Message) error { // Find the incoming request that this write relates to, if any. - var forRequest, replyTo jsonrpc.ID + var forRequest jsonrpc.ID + isResponse := false if resp, ok := msg.(*jsonrpc.Response); ok { // If the message is a response, it relates to its request (of course). forRequest = resp.ID - replyTo = resp.ID + isResponse = true } else { // Otherwise, we check to see if it request was made in the context of an // ongoing request. This may not be the case if the request way made with @@ -611,10 +619,12 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa stream = t.streams[0] } + // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0 + // and the client never did a GET), then memory will grow without bound. Consider a mitigation. stream.outgoing = append(stream.outgoing, data) - if replyTo.IsValid() { + if isResponse { // Once we've put the reply on the queue, it's no longer outstanding. - delete(stream.requests, replyTo) + delete(stream.requests, forRequest) } // Signal streamResponse that new work is available. @@ -635,7 +645,9 @@ func (t *StreamableServerTransport) Close() error { if !t.isDone { t.isDone = true close(t.done) - return t.opts.EventStore.SessionClosed(context.TODO(), t.id) + // TODO: find a way to plumb a context here, or an event store with a long-running + // close operation can take arbitrary time. Alternative: impose a fixed timeout here. + return t.opts.EventStore.SessionClosed(context.TODO(), t.sessionID) } return nil } @@ -643,8 +655,6 @@ func (t *StreamableServerTransport) Close() error { // A StreamableClientTransport is a [Transport] that can communicate with an MCP // endpoint serving the streamable HTTP transport defined by the 2025-03-26 // version of the spec. -// -// TODO(rfindley): support retries and resumption tokens. type StreamableClientTransport struct { url string opts StreamableClientTransportOptions From 64493be392fb2d8f8124e73059f854b00ad67fcc Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 30 Jul 2025 12:26:40 -0400 Subject: [PATCH 052/221] mcp: reject requests with bad IDs (#202) Add an isRequest field to methodInfo, used it to reject non-notification requests that lack a valid ID. Furthermore, lift this validation to the transport layer for HTTP server transports, so that we can preemptively reject bad HTTP requests. Fixes #194 Fixes #197 --- mcp/client.go | 22 ++++---- mcp/server.go | 32 +++++------ mcp/shared.go | 34 ++++++++++-- mcp/sse.go | 6 +++ mcp/sse_test.go | 43 +++++++++++++-- mcp/streamable.go | 12 ++++- mcp/streamable_test.go | 14 ++++- .../conformance/server/missing_fields.txtar | 54 +++++++++++++++++++ 8 files changed, 178 insertions(+), 39 deletions(-) create mode 100644 mcp/testdata/conformance/server/missing_fields.txtar diff --git a/mcp/client.go b/mcp/client.go index b386294c..2cd81e04 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -287,16 +287,16 @@ func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession] // clientMethodInfos maps from the RPC method name to serverMethodInfos. var clientMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete)), - methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)), - methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)), - methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)), - notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), - notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), - notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), - notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)), - notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), - notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)), + methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), true), + methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), true), + methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), true), + methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), true), + notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), false), + notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), false), + notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), false), + notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), false), + notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), false), + notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), false), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -323,7 +323,7 @@ func (cs *ClientSession) receivingMethodHandler() methodHandler { return cs.client.receivingMethodHandler_ } -// getConn implements [session.getConn]. +// getConn implements [Session.getConn]. func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { diff --git a/mcp/server.go b/mcp/server.go index e88455a2..9b823e7f 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -688,22 +688,22 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession] // serverMethodInfos maps from the RPC method name to serverMethodInfos. var serverMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(serverMethod((*Server).complete)), - methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize)), - methodPing: newMethodInfo(sessionMethod((*ServerSession).ping)), - methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts)), - methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt)), - methodListTools: newMethodInfo(serverMethod((*Server).listTools)), - methodCallTool: newMethodInfo(serverMethod((*Server).callTool)), - methodListResources: newMethodInfo(serverMethod((*Server).listResources)), - methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)), - methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), - methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)), - methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)), - methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)), - notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)), - notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), - notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)), + methodComplete: newMethodInfo(serverMethod((*Server).complete), true), + methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), true), + methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), true), + methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), true), + methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), true), + methodListTools: newMethodInfo(serverMethod((*Server).listTools), true), + methodCallTool: newMethodInfo(serverMethod((*Server).callTool), true), + methodListResources: newMethodInfo(serverMethod((*Server).listResources), true), + methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), true), + methodReadResource: newMethodInfo(serverMethod((*Server).readResource), true), + methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), true), + methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), true), + methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), true), + notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), false), + notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), false), + notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), false), } func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } diff --git a/mcp/shared.go b/mcp/shared.go index cb5234a4..0bc7c793 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -123,9 +123,9 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, me } func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Request) (Result, error) { - info, ok := session.receivingMethodInfos()[req.Method] - if !ok { - return nil, jsonrpc2.ErrNotHandled + info, err := checkRequest(req, session.receivingMethodInfos()) + if err != nil { + return nil, err } params, err := info.unmarshalParams(req.Params) if err != nil { @@ -141,8 +141,30 @@ func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Reque return res, nil } +// checkRequest checks the given request against the provided method info, to +// ensure it is a valid MCP request. +// +// If valid, the relevant method info is returned. Otherwise, a non-nil error +// is returned describing why the request is invalid. +// +// This is extracted from request handling so that it can be called in the +// transport layer to preemptively reject bad requests. +func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) { + info, ok := infos[req.Method] + if !ok { + return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) + } + if info.isRequest && !req.ID.IsValid() { + return methodInfo{}, fmt.Errorf("%w: %q missing ID", jsonrpc2.ErrInvalidRequest, req.Method) + } + return info, nil +} + // methodInfo is information about sending and receiving a method. type methodInfo struct { + // isRequest reports whether the method is a JSON-RPC request. + // Otherwise, the method is treated as a notification. + isRequest bool // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) @@ -169,8 +191,12 @@ type paramsPtr[T any] interface { } // newMethodInfo creates a methodInfo from a typedMethodHandler. -func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R]) methodInfo { +// +// If isRequest is set, the method is treated as a request rather than a +// notification. +func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], isRequest bool) methodInfo { return methodInfo{ + isRequest: isRequest, unmarshalParams: func(m json.RawMessage) (Params, error) { var p P if m != nil { diff --git a/mcp/sse.go b/mcp/sse.go index 92a7ca1a..cf44276b 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -129,6 +129,12 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) http.Error(w, "failed to parse body", http.StatusBadRequest) return } + if req, ok := msg.(*jsonrpc.Request); ok { + if _, err := checkRequest(req, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } select { case t.incoming <- msg: w.WriteHeader(http.StatusAccepted) diff --git a/mcp/sse_test.go b/mcp/sse_test.go index b4e8ebad..d486732c 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -5,8 +5,10 @@ package mcp import ( + "bytes" "context" "fmt" + "io" "net/http" "net/http/httptest" "sync/atomic" @@ -24,10 +26,10 @@ func TestSSEServer(t *testing.T) { sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) - conns := make(chan *ServerSession, 1) - sseHandler.onConnection = func(cc *ServerSession) { + serverSessions := make(chan *ServerSession, 1) + sseHandler.onConnection = func(ss *ServerSession) { select { - case conns <- cc: + case serverSessions <- ss: default: } } @@ -54,7 +56,7 @@ func TestSSEServer(t *testing.T) { if err := cs.Ping(ctx, nil); err != nil { t.Fatal(err) } - ss := <-conns + ss := <-serverSessions gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", Arguments: map[string]any{"Name": "user"}, @@ -76,6 +78,39 @@ func TestSSEServer(t *testing.T) { t.Error("Expected custom HTTP client to be used, but it wasn't") } + t.Run("badrequests", func(t *testing.T) { + msgEndpoint := cs.mcpConn.(*sseClientConn).msgEndpoint.String() + + // Test some invalid data, and verify that we get 400s. + badRequests := []struct { + name string + body string + responseContains string + }{ + {"not a method", `{"jsonrpc":"2.0", "method":"notamethod"}`, "not handled"}, + {"missing ID", `{"jsonrpc":"2.0", "method":"ping"}`, "missing ID"}, + } + for _, r := range badRequests { + t.Run(r.name, func(t *testing.T) { + resp, err := http.Post(msgEndpoint, "application/json", bytes.NewReader([]byte(r.body))) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + t.Errorf("Sending bad request %q: got status %d, want %d", r.body, got, want) + } + result, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Reading response: %v", err) + } + if !bytes.Contains(result, []byte(r.responseContains)) { + t.Errorf("Response body does not contain %q:\n%s", r.responseContains, string(result)) + } + }) + } + }) + // Test that closing either end of the connection terminates the other // end. if closeServerFirst { diff --git a/mcp/streamable.go b/mcp/streamable.go index 0921c22d..f12c3011 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -387,8 +387,16 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R } requests := make(map[jsonrpc.ID]struct{}) for _, msg := range incoming { - if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { - requests[req.ID] = struct{}{} + if req, ok := msg.(*jsonrpc.Request); ok { + // Preemptively check that this is a valid request, so that we can fail + // the HTTP request. If we didn't do this, a request with a bad method or + // missing ID could be silently swallowed. + if _, err := checkRequest(req, serverMethodInfos); err != nil { + return http.StatusBadRequest, err.Error() + } + if req.ID.IsValid() { + requests[req.ID] = struct{}{} + } } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 185bc638..1ba20bc1 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -280,7 +280,7 @@ func TestStreamableServerTransport(t *testing.T) { } // Predefined steps, to avoid repetition below. - initReq := req(1, "initialize", &InitializeParams{}) + initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ Capabilities: &serverCapabilities{ Completions: &completionCapabilities{}, @@ -290,7 +290,7 @@ func TestStreamableServerTransport(t *testing.T) { ProtocolVersion: latestProtocolVersion, ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) - initializedMsg := req(0, "initialized", &InitializedParams{}) + initializedMsg := req(0, notificationInitialized, &InitializedParams{}) initialize := step{ Method: "POST", Send: []jsonrpc.Message{initReq}, @@ -438,6 +438,16 @@ func TestStreamableServerTransport(t *testing.T) { Method: "DELETE", StatusCode: http.StatusBadRequest, }, + { + Method: "POST", + Send: []jsonrpc.Message{req(1, "notamethod", nil)}, + StatusCode: http.StatusBadRequest, // notamethod is an invalid method + }, + { + Method: "POST", + Send: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})}, + StatusCode: http.StatusBadRequest, // tools/call must have an ID + }, { Method: "POST", Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, diff --git a/mcp/testdata/conformance/server/missing_fields.txtar b/mcp/testdata/conformance/server/missing_fields.txtar new file mode 100644 index 00000000..c5aa3e55 --- /dev/null +++ b/mcp/testdata/conformance/server/missing_fields.txtar @@ -0,0 +1,54 @@ +Check robustness to missing fields: servers should reject and otherwise ignore +bad requests. + +Fixed bugs: +- No id in 'initialize' should not panic (#197). +- No id in 'ping' should not panic (#194). + +TODO: +- No params in 'initialize' should not panic (#195). + +-- prompts -- +code_review + +-- client -- +{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ + "jsonrpc": "2.0", + "id": 2, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{"jsonrpc":"2.0", "method":"ping"} + +-- server -- +{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "capabilities": { + "completions": {}, + "logging": {}, + "prompts": { + "listChanged": true + } + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} From 0fa4a6c65ba3eaa08e7ba22054e4fce0cfbc78c4 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 30 Jul 2025 12:53:56 -0400 Subject: [PATCH 053/221] CONTRIBUTING.md: timeout policy (#187) Explain that we'll close issues or PRs after a period of inactivity. This will reduce clutter on the issue and PR pages. --- CONTRIBUTING.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 40739735..726ac95f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -118,6 +118,12 @@ The CI system will automatically check that the README is up-to-date by running `make` and verifying no changes result. If you see a CI failure about the README being out of sync, follow the steps above to regenerate it. +## Timeouts + +If a contributor hasn't responded to issue questions or PR comments in two weeks, +the issue or PR may be closed. It can be reopened when the contributor can resume +work. + ## Code of conduct This project follows the [Go Community Code of Conduct](https://go.dev/conduct). From 982a0bceca8b60f8a61a386ea121c475b7fd6351 Mon Sep 17 00:00:00 2001 From: CSK <73425927+cr2007@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:54:41 +0400 Subject: [PATCH 054/221] feat(devcontainer): Adds Dev Container configuration (#154) This fixes #151, which makes it easier for users to easily use GitHub Codespaces, providing them with a VS Code like environment with their preferred settings/themes and the required dependencies, all within the browser. --- .devcontainer/devcontainer.json | 26 ++++++++++++++++++++++++++ README.md | 2 ++ internal/readme/README.src.md | 2 ++ 3 files changed, 30 insertions(+) create mode 100644 .devcontainer/devcontainer.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..ea77d0f8 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,26 @@ +{ + "name": "Go MCP SDK Dev Container", + "image": "mcr.microsoft.com/devcontainers/go", + "features": { + "ghcr.io/devcontainers/features/git-lfs:1": {} + }, + "customizations": { + "vscode": { + "extensions": [ + "golang.go", + "ms-vsliveshare.vsliveshare", + "VisualStudioExptTeam.vscodeintellicode", + "eamodio.gitlens", + "usernamehw.errorlens", + "aaron-bond.better-comments", + "GitHub.vscode-github-actions", + "vscode-icons-team.vscode-icons" + ], + "settings": { + "workbench.iconTheme": "vscode-icons" + } + } + }, + + "postCreateCommand": "go mod tidy" +} diff --git a/README.md b/README.md index d899afdd..76800430 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # MCP Go SDK v0.2.0 +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) + ***BREAKING CHANGES*** This version contains breaking changes. diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 0e239f81..bf9faa26 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,5 +1,7 @@ # MCP Go SDK v0.2.0 +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) + ***BREAKING CHANGES*** This version contains breaking changes. From bacca7a1d24bef242cbca2078c0daedea854fbe0 Mon Sep 17 00:00:00 2001 From: Francis Zhou Date: Thu, 31 Jul 2025 01:25:06 +0800 Subject: [PATCH 055/221] Implement json.Unmarshaler to CreateMessageResult (#191) When calling `ServerSession.CreateMessage`, will get an json unmarshal error because `CreateMessageResult.Content` is an interface, and it doesn't implement `json.Unmarshaller`. Implement the `json.Unmarshaller` to `CreateMessageResult`, let it able to unmarshal the client response without error --- mcp/mcp_test.go | 2 +- mcp/protocol.go | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 3c222eb4..4c728107 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -125,7 +125,7 @@ func TestEndToEnd(t *testing.T) { loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { - return &CreateMessageResult{Model: "aModel"}, nil + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, diff --git a/mcp/protocol.go b/mcp/protocol.go index 00dcd14d..3ca6cb5e 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -264,6 +264,23 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } +func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { + type result CreateMessageResult // avoid recursion + var wire struct { + result + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *r = CreateMessageResult(wire.result) + return nil +} + type GetPromptParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. From 9f48a04971eefa7556b2b134a8a9133d36198b8d Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 30 Jul 2025 14:59:47 -0400 Subject: [PATCH 056/221] mcp: remove unused functions (cleanup) (#207) Remove a few functions that were flagged by gopls as unused. We can always add them back if necessary. --- mcp/mcp_test.go | 14 -------------- mcp/server.go | 4 ---- 2 files changed, 18 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4c728107..819edeb6 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -520,17 +520,6 @@ func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadRes }, nil } -// Add calls the given function to add the named features. -func add[T any](m map[string]T, add func(...T), names ...string) { - for _, name := range names { - feat, ok := m[name] - if !ok { - panic("missing feature " + name) - } - add(feat) - } -} - // errorCode returns the code associated with err. // If err is nil, it returns 0. // If there is no code, it returns -1. @@ -837,9 +826,6 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { } } -// A function, because schemas must form a tree (they have hidden state). -func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} } - func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { return nil, nil } diff --git a/mcp/server.go b/mcp/server.go index 9b823e7f..220c28d3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -423,10 +423,6 @@ func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, boo // Lexical path traversal attacks, where the path has ".." elements that escape dir, // are always caught. Go 1.24 and above also protects against symlink-based attacks, // where symlinks under dir lead out of the tree. -func (s *Server) fileResourceHandler(dir string) ResourceHandler { - return fileResourceHandler(dir) -} - func fileResourceHandler(dir string) ResourceHandler { // Convert dir to an absolute path. dirFilepath, err := filepath.Abs(dir) From 56734edf465110c911f564f6bd23a20f5dea7af2 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 30 Jul 2025 15:19:51 -0400 Subject: [PATCH 057/221] mcp: reject notifications with unexpected ID field (#204) Use the new 'isRequest' field on methodInfo to also reject notifications with an unexpected ID field. For #196 --- mcp/shared.go | 5 ++++- .../{missing_fields.txtar => bad_requests.txtar} | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) rename mcp/testdata/conformance/server/{missing_fields.txtar => bad_requests.txtar} (78%) diff --git a/mcp/shared.go b/mcp/shared.go index 0bc7c793..46720eed 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -155,7 +155,10 @@ func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) } if info.isRequest && !req.ID.IsValid() { - return methodInfo{}, fmt.Errorf("%w: %q missing ID", jsonrpc2.ErrInvalidRequest, req.Method) + return methodInfo{}, fmt.Errorf("%w: missing ID, %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + if !info.isRequest && req.ID.IsValid() { + return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } return info, nil } diff --git a/mcp/testdata/conformance/server/missing_fields.txtar b/mcp/testdata/conformance/server/bad_requests.txtar similarity index 78% rename from mcp/testdata/conformance/server/missing_fields.txtar rename to mcp/testdata/conformance/server/bad_requests.txtar index c5aa3e55..d2f278bc 100644 --- a/mcp/testdata/conformance/server/missing_fields.txtar +++ b/mcp/testdata/conformance/server/bad_requests.txtar @@ -4,6 +4,7 @@ bad requests. Fixed bugs: - No id in 'initialize' should not panic (#197). - No id in 'ping' should not panic (#194). +- Notifications with IDs should not be treated like requests. TODO: - No params in 'initialize' should not panic (#195). @@ -31,6 +32,7 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{"jsonrpc":"2.0", "id": 3, "method":"notifications/initialized"} {"jsonrpc":"2.0", "method":"ping"} -- server -- @@ -52,3 +54,11 @@ code_review } } } +{ + "jsonrpc": "2.0", + "id": 3, + "error": { + "code": -32600, + "message": "JSON RPC invalid request: unexpected id for \"notifications/initialized\"" + } +} From a9a503f74ac2b2bf552988006bfea5dabfd0ea1a Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Wed, 30 Jul 2025 17:00:50 -0400 Subject: [PATCH 058/221] mcp/streamable: add persistent SSE GET listener (#206) This CL adds the optional persistent SSE GET listener as specified in section 2.2. https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server This enables server initiated SSE streams. For #10 --- mcp/streamable.go | 33 ++++++++++++++++++++++----------- mcp/streamable_test.go | 39 +++++++++++++++++++++++++++++++++++---- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index f12c3011..3f53a689 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -743,6 +743,12 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er ctx: connCtx, cancel: cancel, } + // Start the persistent SSE listener right away. + // Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint. + // This can be used to open an SSE stream, allowing the server to + // communicate to the client, without the client first sending data via HTTP POST. + go conn.handleSSE(nil, true) + return conn, nil } @@ -867,7 +873,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string switch ct := resp.Header.Get("Content-Type"); ct { case "text/event-stream": // Section 2.1: The SSE stream is initiated after a POST. - go s.handleSSE(resp) + go s.handleSSE(resp, false) case "application/json": body, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -887,13 +893,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return sessionID, nil } -// handleSSE manages the entire lifecycle of an SSE connection. It processes -// an incoming Server-Sent Events stream and automatically handles reconnection -// logic if the stream breaks. -func (s *streamableClientConn) handleSSE(initialResp *http.Response) { +// handleSSE manages the lifecycle of an SSE connection. It can be either +// persistent (for the main GET listener) or temporary (for a POST response). +func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { resp := initialResp var lastEventID string - for { eventID, clientClosed := s.processStream(resp) lastEventID = eventID @@ -902,6 +906,11 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { if clientClosed { return } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && !persistent { + return + } // The stream was interrupted or ended by the server. Attempt to reconnect. newResp, err := s.reconnect(lastEventID) @@ -923,9 +932,13 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { // processStream reads from a single response body, sending events to the // incoming channel. It returns the ID of the last processed event, any error // that occurred, and a flag indicating if the connection was closed by the client. +// If resp is nil, it returns "", false. func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { - defer resp.Body.Close() + if resp == nil { + return "", false + } + defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { return lastEventID, false @@ -939,13 +952,11 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s case s.incoming <- evt.Data: case <-s.done: // The connection was closed by the client; exit gracefully. - return lastEventID, true + return "", true } } - // The loop finished without an error, indicating the server closed the stream. - // We'll attempt to reconnect, so this is not a client-side close. - return lastEventID, false + return "", false } // reconnect handles the logic of retrying a connection with an exponential diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 1ba20bc1..81af2760 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -168,7 +168,7 @@ func TestClientReplay(t *testing.T) { clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) // 4. Read and verify messages until the server signals it's ready for the proxy kill. - receivedNotifications := readProgressNotifications(t, ctx, notifications, 2) + receivedNotifications := readNotifications(t, ctx, notifications, 2) wantReceived := []string{"msg1", "msg2"} if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { @@ -201,7 +201,7 @@ func TestClientReplay(t *testing.T) { // 7. Continue reading from the same connection object. // Its internal logic should successfully retry, reconnect to the new proxy, // and receive the replayed messages. - recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2) + recoveredNotifications := readNotifications(t, ctx, notifications, 2) // 8. Verify the correct messages were received on the recovered connection. wantRecovered := []string{"msg3", "msg4"} @@ -211,8 +211,39 @@ func TestClientReplay(t *testing.T) { } } -// Helper to read a specific number of progress notifications. -func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { +// TestServerInitiatedSSE verifies that the persistent SSE connection remains +// open and can receive server-initiated events. +func TestServerInitiatedSSE(t *testing.T) { + notifications := make(chan string) + server := NewServer(testImpl, nil) + + httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { + notifications <- "toolListChanged" + }, + }) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + receivedNotifications := readNotifications(t, ctx, notifications, 1) + wantReceived := []string{"toolListChanged"} + if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) + } +} + +// Helper to read a specific number of notifications. +func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { t.Helper() var collectedNotifications []string for { From b34ba2149ddefc5ea0318240fc096d032b472216 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 31 Jul 2025 10:57:59 -0400 Subject: [PATCH 059/221] mcp: handle bad or missing params (#210) Audit all cases where params must not be null, and enforce this using methodInfo via a new methodFlags bitfield that parameterizes method requirements. Write extensive conformance tests catching all the (server-side) crashes that were possible. We should go further and validate schema against the spec, but that is more indirect, more complicated, and potentially slower. For now we adopt this more explicit approach. Still TODO in a subsequent CL: verify the client side of this with client conformance tests. Additionally, improve some error messages that were redundant or leaked internal implementation details. Fixes #195 --- internal/jsonrpc2/wire.go | 18 +++---- mcp/client.go | 24 +++++---- mcp/server.go | 39 ++++++++------ mcp/shared.go | 45 ++++++++++++---- mcp/sse_test.go | 2 +- .../conformance/server/bad_requests.txtar | 51 +++++++++++++++++-- mcp/testdata/conformance/server/prompts.txtar | 12 ++++- .../conformance/server/resources.txtar | 21 ++++++++ mcp/testdata/conformance/server/tools.txtar | 10 ++++ 9 files changed, 170 insertions(+), 52 deletions(-) diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index 309b8002..b143dcd3 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -13,30 +13,30 @@ import ( var ( // ErrParse is used when invalid JSON was received by the server. - ErrParse = NewError(-32700, "JSON RPC parse error") + ErrParse = NewError(-32700, "parse error") // ErrInvalidRequest is used when the JSON sent is not a valid Request object. - ErrInvalidRequest = NewError(-32600, "JSON RPC invalid request") + ErrInvalidRequest = NewError(-32600, "invalid request") // ErrMethodNotFound should be returned by the handler when the method does // not exist / is not available. - ErrMethodNotFound = NewError(-32601, "JSON RPC method not found") + ErrMethodNotFound = NewError(-32601, "method not found") // ErrInvalidParams should be returned by the handler when method // parameter(s) were invalid. - ErrInvalidParams = NewError(-32602, "JSON RPC invalid params") + ErrInvalidParams = NewError(-32602, "invalid params") // ErrInternal indicates a failure to process a call correctly - ErrInternal = NewError(-32603, "JSON RPC internal error") + ErrInternal = NewError(-32603, "internal error") // The following errors are not part of the json specification, but // compliant extensions specific to this implementation. // ErrServerOverloaded is returned when a message was refused due to a // server being temporarily unable to accept any new messages. - ErrServerOverloaded = NewError(-32000, "JSON RPC overloaded") + ErrServerOverloaded = NewError(-32000, "overloaded") // ErrUnknown should be used for all non coded errors. - ErrUnknown = NewError(-32001, "JSON RPC unknown error") + ErrUnknown = NewError(-32001, "unknown error") // ErrServerClosing is returned for calls that arrive while the server is closing. - ErrServerClosing = NewError(-32004, "JSON RPC server is closing") + ErrServerClosing = NewError(-32004, "server is closing") // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. - ErrClientClosing = NewError(-32003, "JSON RPC client is closing") + ErrClientClosing = NewError(-32003, "client is closing") ) const wireVersion = "2.0" diff --git a/mcp/client.go b/mcp/client.go index 2cd81e04..9e3f3935 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -286,17 +286,21 @@ func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession] } // clientMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. var clientMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), true), - methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), true), - methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), true), - methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), true), - notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), false), - notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), false), - notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), false), - notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), false), - notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), false), - notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), false), + methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), 0), + methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), missingParamsOK), + methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), + methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), 0), + notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), + notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), + notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), + notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), + notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), notification), + notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), notification), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { diff --git a/mcp/server.go b/mcp/server.go index 220c28d3..08effd79 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -683,23 +683,27 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession] } // serverMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. var serverMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(serverMethod((*Server).complete), true), - methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), true), - methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), true), - methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), true), - methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), true), - methodListTools: newMethodInfo(serverMethod((*Server).listTools), true), - methodCallTool: newMethodInfo(serverMethod((*Server).callTool), true), - methodListResources: newMethodInfo(serverMethod((*Server).listResources), true), - methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), true), - methodReadResource: newMethodInfo(serverMethod((*Server).readResource), true), - methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), true), - methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), true), - methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), true), - notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), false), - notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), false), - notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), false), + methodComplete: newMethodInfo(serverMethod((*Server).complete), 0), + methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), 0), + methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), missingParamsOK), + methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), + methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), 0), + methodListTools: newMethodInfo(serverMethod((*Server).listTools), missingParamsOK), + methodCallTool: newMethodInfo(serverMethod((*Server).callTool), 0), + methodListResources: newMethodInfo(serverMethod((*Server).listResources), missingParamsOK), + methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), + methodReadResource: newMethodInfo(serverMethod((*Server).readResource), 0), + methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), 0), + methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), 0), + methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), notification|missingParamsOK), + notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), + notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), notification), } func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } @@ -744,6 +748,9 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, } func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { + if params == nil { + return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) + } ss.mu.Lock() ss.initializeParams = params ss.mu.Unlock() diff --git a/mcp/shared.go b/mcp/shared.go index 46720eed..54038f5f 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -129,7 +129,7 @@ func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Reque } params, err := info.unmarshalParams(req.Params) if err != nil { - return nil, fmt.Errorf("handleRequest %q: %w", req.Method, err) + return nil, fmt.Errorf("handling '%s': %w", req.Method, err) } mh := session.receivingMethodHandler().(MethodHandler[S]) @@ -154,20 +154,28 @@ func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo if !ok { return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) } - if info.isRequest && !req.ID.IsValid() { - return methodInfo{}, fmt.Errorf("%w: missing ID, %q", jsonrpc2.ErrInvalidRequest, req.Method) - } - if !info.isRequest && req.ID.IsValid() { + if info.flags¬ification != 0 && req.ID.IsValid() { return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } + if info.flags¬ification == 0 && !req.ID.IsValid() { + return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + // missingParamsOK is checked here to catch the common case where "params" is + // missing entirely. + // + // However, it's checked again after unmarshalling to catch the rare but + // possible case where "params" is JSON null (see https://go.dev/issue/33835). + if info.flags&missingParamsOK == 0 && len(req.Params) == 0 { + return methodInfo{}, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } return info, nil } // methodInfo is information about sending and receiving a method. type methodInfo struct { - // isRequest reports whether the method is a JSON-RPC request. - // Otherwise, the method is treated as a notification. - isRequest bool + // flags is a collection of flags controlling how the JSONRPC method is + // handled. See individual flag values for documentation. + flags methodFlags // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) @@ -193,13 +201,20 @@ type paramsPtr[T any] interface { Params } +type methodFlags int + +const ( + notification methodFlags = 1 << iota // method is a notification, not request + missingParamsOK // params may be missing or null +) + // newMethodInfo creates a methodInfo from a typedMethodHandler. // // If isRequest is set, the method is treated as a request rather than a // notification. -func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], isRequest bool) methodInfo { +func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], flags methodFlags) methodInfo { return methodInfo{ - isRequest: isRequest, + flags: flags, unmarshalParams: func(m json.RawMessage) (Params, error) { var p P if m != nil { @@ -207,6 +222,16 @@ func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHand return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) } } + // We must check missingParamsOK here, in addition to checkRequest, to + // catch the edge cases where "params" is set to JSON null. + // See also https://go.dev/issue/33835. + // + // We need to ensure that p is non-null to guard against crashes, as our + // internal code or externally provided handlers may assume that params + // is non-null. + if flags&missingParamsOK == 0 && p == nil { + return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } return orZero[Params](p), nil }, handleMethod: MethodHandler[S](func(ctx context.Context, session S, _ string, params Params) (Result, error) { diff --git a/mcp/sse_test.go b/mcp/sse_test.go index d486732c..35fdbdbf 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -88,7 +88,7 @@ func TestSSEServer(t *testing.T) { responseContains string }{ {"not a method", `{"jsonrpc":"2.0", "method":"notamethod"}`, "not handled"}, - {"missing ID", `{"jsonrpc":"2.0", "method":"ping"}`, "missing ID"}, + {"missing ID", `{"jsonrpc":"2.0", "method":"ping"}`, "missing id"}, } for _, r := range badRequests { t.Run(r.name, func(t *testing.T) { diff --git a/mcp/testdata/conformance/server/bad_requests.txtar b/mcp/testdata/conformance/server/bad_requests.txtar index d2f278bc..a8767ad2 100644 --- a/mcp/testdata/conformance/server/bad_requests.txtar +++ b/mcp/testdata/conformance/server/bad_requests.txtar @@ -4,10 +4,11 @@ bad requests. Fixed bugs: - No id in 'initialize' should not panic (#197). - No id in 'ping' should not panic (#194). -- Notifications with IDs should not be treated like requests. - -TODO: - No params in 'initialize' should not panic (#195). +- Notifications with IDs should not be treated like requests. (#196) +- No params in 'logging/setLevel' should not panic. +- No params in 'completion/complete' should not panic. +- JSON null params should also not panic in these cases. -- prompts -- code_review @@ -22,6 +23,11 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" +} { "jsonrpc": "2.0", "id": 2, @@ -32,10 +38,21 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } -{"jsonrpc":"2.0", "id": 3, "method":"notifications/initialized"} +{"jsonrpc":"2.0", "id": 3, "method": "notifications/initialized"} {"jsonrpc":"2.0", "method":"ping"} +{"jsonrpc":"2.0", "id": 4, "method": "logging/setLevel"} +{"jsonrpc":"2.0", "id": 5, "method": "completion/complete"} +{"jsonrpc":"2.0", "id": 4, "method": "logging/setLevel", "params": null} -- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} { "jsonrpc": "2.0", "id": 2, @@ -59,6 +76,30 @@ code_review "id": 3, "error": { "code": -32600, - "message": "JSON RPC invalid request: unexpected id for \"notifications/initialized\"" + "message": "invalid request: unexpected id for \"notifications/initialized\"" + } +} +{ + "jsonrpc": "2.0", + "id": 4, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 4, + "error": { + "code": -32600, + "message": "handling 'logging/setLevel': invalid request: missing required \"params\"" } } diff --git a/mcp/testdata/conformance/server/prompts.txtar b/mcp/testdata/conformance/server/prompts.txtar index 078e0915..8eb678cc 100644 --- a/mcp/testdata/conformance/server/prompts.txtar +++ b/mcp/testdata/conformance/server/prompts.txtar @@ -1,7 +1,8 @@ Check behavior of a server with just prompts. Fixed bugs: -- empty tools lists should not be returned as 'null' +- Empty tools lists should not be returned as 'null'. +- No params in 'prompts/get' should not panic. -- prompts -- code_review @@ -19,6 +20,7 @@ code_review } { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } +{ "jsonrpc": "2.0", "id": 5, "method": "prompts/get" } -- server -- { "jsonrpc": "2.0", @@ -63,3 +65,11 @@ code_review ] } } +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} diff --git a/mcp/testdata/conformance/server/resources.txtar b/mcp/testdata/conformance/server/resources.txtar index 5bb5515d..fe9a1b78 100644 --- a/mcp/testdata/conformance/server/resources.txtar +++ b/mcp/testdata/conformance/server/resources.txtar @@ -2,6 +2,9 @@ Check behavior of a server with just resources. Fixed bugs: - A resource result holds a slice of contents, not just one. +- No params in 'resource/read' should not panic. +- No params in 'resources/subscribe' should not panic. +- No params in 'resources/unsubscribe' should not panic. -- resources -- info @@ -39,6 +42,8 @@ info.txt "roots": [] } } +{ "jsonrpc": "2.0", "id": 4, "method": "resources/read" } +{ "jsonrpc": "2.0", "id": 5, "method": "resources/subscribe" } -- server -- { "jsonrpc": "2.0", @@ -107,3 +112,19 @@ info.txt ] } } +{ + "jsonrpc": "2.0", + "id": 4, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index 07ad942a..7525ac15 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -4,6 +4,7 @@ Fixed bugs: - "tools/list" can have missing params - "_meta" should not be nil - empty resource or prompts should not be returned as 'null' +- the server should not crash when params are passed to tools/call -- tools -- greet @@ -22,6 +23,7 @@ greet { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } +{ "jsonrpc": "2.0", "id": 5, "method": "tools/call" } -- server -- { "jsonrpc": "2.0", @@ -81,3 +83,11 @@ greet "prompts": [] } } +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} From 0cde9623d320a5b9bab251b0e8d053bd31ab7adf Mon Sep 17 00:00:00 2001 From: cryo Date: Sat, 2 Aug 2025 04:20:55 +0800 Subject: [PATCH 060/221] mcp: fix comments (#221) ### change: Fix typos in some comments. --- mcp/server.go | 2 +- mcp/shared.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 08effd79..b96edbe8 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -215,7 +215,7 @@ func (s *Server) RemoveResources(uris ...string) { func() bool { return s.resources.remove(uris...) }) } -// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces on with the same URI. +// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI. // AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, diff --git a/mcp/shared.go b/mcp/shared.go index 54038f5f..319071f2 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -86,7 +86,7 @@ func defaultSendingMethodHandler[S Session](ctx context.Context, session S, meth return res, nil } -// Helper methods to avoid typed nil. +// Helper method to avoid typed nil. func orZero[T any, P *U, U any](p P) T { if p == nil { var zero T From faaf77dc0557c76add88ba4ea4c066b58c16d199 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Mon, 4 Aug 2025 10:12:01 -0400 Subject: [PATCH 061/221] mcp/server: advertise completions only if installed (#213) Previously, completions would always be advertised as a capability even if the CompletionHandler was not installed by the server. This CL fixes that. --- mcp/server.go | 7 +-- mcp/server_test.go | 51 +++++++++++-------- mcp/streamable_test.go | 10 ++-- .../conformance/server/bad_requests.txtar | 1 - mcp/testdata/conformance/server/prompts.txtar | 1 - .../conformance/server/resources.txtar | 1 - mcp/testdata/conformance/server/tools.txtar | 1 - .../conformance/server/version-latest.txtar | 1 - .../conformance/server/version-older.txtar | 1 - 9 files changed, 38 insertions(+), 36 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index b96edbe8..ba2bd0a9 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -238,9 +238,7 @@ func (s *Server) capabilities() *serverCapabilities { defer s.mu.Unlock() caps := &serverCapabilities{ - // TODO(samthanawalla): check for completionHandler before advertising capability. - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, + Logging: &loggingCapabilities{}, } if s.opts.HasTools || s.tools.len() > 0 { caps.Tools = &toolCapabilities{ListChanged: true} @@ -254,6 +252,9 @@ func (s *Server) capabilities() *serverCapabilities { caps.Resources.Subscribe = true } } + if s.opts.CompletionHandler != nil { + caps.Completions = &completionCapabilities{} + } return caps } diff --git a/mcp/server_test.go b/mcp/server_test.go index 79c607e9..0b853a33 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -242,8 +242,7 @@ func TestServerCapabilities(t *testing.T) { name: "No capabilities", configureServer: func(s *Server) {}, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, + Logging: &loggingCapabilities{}, }, }, { @@ -252,9 +251,8 @@ func TestServerCapabilities(t *testing.T) { s.AddPrompt(&Prompt{Name: "p"}, nil) }, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, }, }, { @@ -263,9 +261,8 @@ func TestServerCapabilities(t *testing.T) { s.AddResource(&Resource{URI: "file:///r"}, nil) }, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Resources: &resourceCapabilities{ListChanged: true}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true}, }, }, { @@ -274,9 +271,8 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Resources: &resourceCapabilities{ListChanged: true}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true}, }, }, { @@ -293,9 +289,8 @@ func TestServerCapabilities(t *testing.T) { }, }, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, }, }, { @@ -304,9 +299,21 @@ func TestServerCapabilities(t *testing.T) { s.AddTool(tool, nil) }, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, + }, + }, + { + name: "With completions", + configureServer: func(s *Server) {}, + serverOpts: ServerOptions{ + CompletionHandler: func(ctx context.Context, ss *ServerSession, params *CompleteParams) (*CompleteResult, error) { + return nil, nil + }, + }, + wantCapabilities: &serverCapabilities{ Logging: &loggingCapabilities{}, - Tools: &toolCapabilities{ListChanged: true}, + Completions: &completionCapabilities{}, }, }, { @@ -324,6 +331,9 @@ func TestServerCapabilities(t *testing.T) { UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { return nil }, + CompletionHandler: func(ctx context.Context, ss *ServerSession, params *CompleteParams) (*CompleteResult, error) { + return nil, nil + }, }, wantCapabilities: &serverCapabilities{ Completions: &completionCapabilities{}, @@ -342,11 +352,10 @@ func TestServerCapabilities(t *testing.T) { HasTools: true, }, wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, - Tools: &toolCapabilities{ListChanged: true}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true}, + Tools: &toolCapabilities{ListChanged: true}, }, }, } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 81af2760..af06bd39 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -314,9 +314,8 @@ func TestStreamableServerTransport(t *testing.T) { initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ Capabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Tools: &toolCapabilities{ListChanged: true}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion, ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, @@ -732,9 +731,8 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { } initResult := &InitializeResult{ Capabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Tools: &toolCapabilities{ListChanged: true}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion, ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, diff --git a/mcp/testdata/conformance/server/bad_requests.txtar b/mcp/testdata/conformance/server/bad_requests.txtar index a8767ad2..e9e9d483 100644 --- a/mcp/testdata/conformance/server/bad_requests.txtar +++ b/mcp/testdata/conformance/server/bad_requests.txtar @@ -58,7 +58,6 @@ code_review "id": 2, "result": { "capabilities": { - "completions": {}, "logging": {}, "prompts": { "listChanged": true diff --git a/mcp/testdata/conformance/server/prompts.txtar b/mcp/testdata/conformance/server/prompts.txtar index 8eb678cc..3fd036e6 100644 --- a/mcp/testdata/conformance/server/prompts.txtar +++ b/mcp/testdata/conformance/server/prompts.txtar @@ -27,7 +27,6 @@ code_review "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {}, "prompts": { "listChanged": true diff --git a/mcp/testdata/conformance/server/resources.txtar b/mcp/testdata/conformance/server/resources.txtar index fe9a1b78..ae2e23cb 100644 --- a/mcp/testdata/conformance/server/resources.txtar +++ b/mcp/testdata/conformance/server/resources.txtar @@ -50,7 +50,6 @@ info.txt "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {}, "resources": { "listChanged": true diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index 7525ac15..b4068d1c 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -30,7 +30,6 @@ greet "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {}, "tools": { "listChanged": true diff --git a/mcp/testdata/conformance/server/version-latest.txtar b/mcp/testdata/conformance/server/version-latest.txtar index 89454fb3..75317676 100644 --- a/mcp/testdata/conformance/server/version-latest.txtar +++ b/mcp/testdata/conformance/server/version-latest.txtar @@ -18,7 +18,6 @@ response with its latest supported version. "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {} }, "protocolVersion": "2025-06-18", diff --git a/mcp/testdata/conformance/server/version-older.txtar b/mcp/testdata/conformance/server/version-older.txtar index 55240954..82292630 100644 --- a/mcp/testdata/conformance/server/version-older.txtar +++ b/mcp/testdata/conformance/server/version-older.txtar @@ -18,7 +18,6 @@ support. "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {} }, "protocolVersion": "2024-11-05", From 1e40ede369321e9fbae0314905cec3fcae45e928 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 4 Aug 2025 10:50:57 -0400 Subject: [PATCH 062/221] jsonschema: options for inference (#185) - Add ForOptions to hold options for schema inference. - Replace ForLax with ForOptions.IgnoreBadTypes. - Add an option to provide schemas for arbitrary types. - Provide a default mapping from types to schemas that includes stdlib types with MarshalJSON methods. - Add Schema.CloneSchemas. This is needed to make copies of the schemas in the above map: a schema cannot appear twice in a parent schema, because schema addresses matter when resolving internal references. --- jsonschema/infer.go | 99 ++++++++++++++++++++++++++------------- jsonschema/infer_test.go | 66 +++++++++++++++----------- jsonschema/schema.go | 36 ++++++++++++++ jsonschema/schema_test.go | 31 ++++++++++++ mcp/tool.go | 2 +- 5 files changed, 173 insertions(+), 61 deletions(-) diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 9ff0ddd5..7b6b7e2b 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -8,15 +8,40 @@ package jsonschema import ( "fmt" + "log/slog" + "math/big" "reflect" "regexp" + "time" "github.com/modelcontextprotocol/go-sdk/internal/util" ) +// ForOptions are options for the [For] function. +type ForOptions struct { + // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON Schema + // are ignored instead of causing an error. + // This allows callers to adjust the resulting schema using custom knowledge. + // For example, an interface type where all the possible implementations are + // known can be described with "oneof". + IgnoreInvalidTypes bool + + // TypeSchemas maps types to their schemas. + // If [For] encounters a type equal to a type of a key in this map, the + // corresponding value is used as the resulting schema (after cloning to + // ensure uniqueness). + // Types in this map override the default translations, as described + // in [For]'s documentation. + TypeSchemas map[any]*Schema +} + // For constructs a JSON schema object for the given type argument. +// If non-nil, the provided options configure certain aspects of this contruction, +// described below. + +// It translates Go types into compatible JSON schema types, as follows. +// These defaults can be overridden by [ForOptions.TypeSchemas]. // -// It translates Go types into compatible JSON schema types, as follows: // - Strings have schema type "string". // - Bools have schema type "boolean". // - Signed and unsigned integer types have schema type "integer". @@ -29,48 +54,51 @@ import ( // Their properties are derived from exported struct fields, using the // struct field JSON name. Fields that are marked "omitempty" are // considered optional; all other fields become required properties. +// - Some types in the standard library that implement json.Marshaler +// translate to schemas that match the values to which they marshal. +// For example, [time.Time] translates to the schema for strings. +// +// For will return an error if there is a cycle in the types. // -// For returns an error if t contains (possibly recursively) any of the following Go -// types, as they are incompatible with the JSON schema spec. +// By default, For returns an error if t contains (possibly recursively) any of the +// following Go types, as they are incompatible with the JSON schema spec. +// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead. // - maps with key other than 'string' // - function types // - channel types // - complex numbers // - unsafe pointers // -// It will return an error if there is a cycle in the types. -// // This function recognizes struct field tags named "jsonschema". // A jsonschema tag on a field is used as the description for the corresponding property. // For future compatibility, descriptions must not start with "WORD=", where WORD is a // sequence of non-whitespace characters. -func For[T any]() (*Schema, error) { - // TODO: consider skipping incompatible fields, instead of failing. - seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen, false) - if err != nil { - var z T - return nil, fmt.Errorf("For[%T](): %w", z, err) +func For[T any](opts *ForOptions) (*Schema, error) { + if opts == nil { + opts = &ForOptions{} } - return s, nil -} - -// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of -// returning an error. That allows callers to adjust the resulting schema using custom knowledge. -// For example, an interface type where all the possible implementations are known -// can be described with "oneof". -func ForLax[T any]() (*Schema, error) { - // TODO: consider skipping incompatible fields, instead of failing. - seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen, true) + schemas := make(map[reflect.Type]*Schema) + // Add types from the standard library that have MarshalJSON methods. + ss := &Schema{Type: "string"} + schemas[reflect.TypeFor[time.Time]()] = ss + schemas[reflect.TypeFor[slog.Level]()] = ss + schemas[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} + schemas[reflect.TypeFor[big.Rat]()] = ss + schemas[reflect.TypeFor[big.Float]()] = ss + + // Add types from the options. They override the default ones. + for v, s := range opts.TypeSchemas { + schemas[reflect.TypeOf(v)] = s + } + s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) if err != nil { var z T - return nil, fmt.Errorf("ForLax[%T](): %w", z, err) + return nil, fmt.Errorf("For[%T](): %w", z, err) } return s, nil } -func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) { +func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -89,6 +117,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err defer delete(seen, t) } + if s := schemas[t]; s != nil { + return s.CloneSchemas(), nil + } + var ( s = new(Schema) err error @@ -111,7 +143,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err case reflect.Map: if t.Key().Kind() != reflect.String { - if lax { + if ignore { return nil, nil // ignore } return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) @@ -119,22 +151,22 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err if t.Key().Kind() != reflect.String { } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem(), seen, lax) + s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } - if lax && s.AdditionalProperties == nil { + if ignore && s.AdditionalProperties == nil { // Ignore if the element type is invalid. return nil, nil } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem(), seen, lax) + s.Items, err = forType(t.Elem(), seen, ignore, schemas) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } - if lax && s.Items == nil { + if ignore && s.Items == nil { // Ignore if the element type is invalid. return nil, nil } @@ -160,11 +192,11 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err if s.Properties == nil { s.Properties = make(map[string]*Schema) } - fs, err := forType(field.Type, seen, lax) + fs, err := forType(field.Type, seen, ignore, schemas) if err != nil { return nil, err } - if lax && fs == nil { + if ignore && fs == nil { // Skip fields of invalid type. continue } @@ -184,7 +216,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err } default: - if lax { + if ignore { // Ignore. return nil, nil } @@ -194,6 +226,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err s.Types = []string{"null", s.Type} s.Type = "" } + schemas[t] = s return s, nil } diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 8c8feec0..1a0895b4 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -5,22 +5,30 @@ package jsonschema_test import ( + "log/slog" + "math/big" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/jsonschema" ) -func forType[T any](lax bool) *jsonschema.Schema { +type custom int + +func forType[T any](ignore bool) *jsonschema.Schema { var s *jsonschema.Schema var err error - if lax { - s, err = jsonschema.ForLax[T]() - } else { - s, err = jsonschema.For[T]() + + opts := &jsonschema.ForOptions{ + IgnoreInvalidTypes: ignore, + TypeSchemas: map[any]*jsonschema.Schema{ + custom(0): {Type: "custom"}, + }, } + s, err = jsonschema.For[T](opts) if err != nil { panic(err) } @@ -40,19 +48,23 @@ func TestFor(t *testing.T) { want *jsonschema.Schema } - tests := func(lax bool) []test { + tests := func(ignore bool) []test { return []test{ - {"string", forType[string](lax), &schema{Type: "string"}}, - {"int", forType[int](lax), &schema{Type: "integer"}}, - {"int16", forType[int16](lax), &schema{Type: "integer"}}, - {"uint32", forType[int16](lax), &schema{Type: "integer"}}, - {"float64", forType[float64](lax), &schema{Type: "number"}}, - {"bool", forType[bool](lax), &schema{Type: "boolean"}}, - {"intmap", forType[map[string]int](lax), &schema{ + {"string", forType[string](ignore), &schema{Type: "string"}}, + {"int", forType[int](ignore), &schema{Type: "integer"}}, + {"int16", forType[int16](ignore), &schema{Type: "integer"}}, + {"uint32", forType[int16](ignore), &schema{Type: "integer"}}, + {"float64", forType[float64](ignore), &schema{Type: "number"}}, + {"bool", forType[bool](ignore), &schema{Type: "boolean"}}, + {"time", forType[time.Time](ignore), &schema{Type: "string"}}, + {"level", forType[slog.Level](ignore), &schema{Type: "string"}}, + {"bigint", forType[big.Int](ignore), &schema{Types: []string{"null", "string"}}}, + {"custom", forType[custom](ignore), &schema{Type: "custom"}}, + {"intmap", forType[map[string]int](ignore), &schema{ Type: "object", AdditionalProperties: &schema{Type: "integer"}, }}, - {"anymap", forType[map[string]any](lax), &schema{ + {"anymap", forType[map[string]any](ignore), &schema{ Type: "object", AdditionalProperties: &schema{}, }}, @@ -66,7 +78,7 @@ func TestFor(t *testing.T) { NoSkip string `json:",omitempty"` unexported float64 unexported2 int `json:"No"` - }](lax), + }](ignore), &schema{ Type: "object", Properties: map[string]*schema{ @@ -81,7 +93,7 @@ func TestFor(t *testing.T) { }, { "no sharing", - forType[struct{ X, Y int }](lax), + forType[struct{ X, Y int }](ignore), &schema{ Type: "object", Properties: map[string]*schema{ @@ -97,7 +109,7 @@ func TestFor(t *testing.T) { forType[struct { A S S - }](lax), + }](ignore), &schema{ Type: "object", Properties: map[string]*schema{ @@ -165,7 +177,7 @@ func TestFor(t *testing.T) { } func forErr[T any]() error { - _, err := jsonschema.For[T]() + _, err := jsonschema.For[T](nil) return err } @@ -209,7 +221,7 @@ func TestForWithMutation(t *testing.T) { D [3]S E *bool } - s, err := jsonschema.For[T]() + s, err := jsonschema.For[T](nil) if err != nil { t.Fatalf("For: %v", err) } @@ -220,7 +232,7 @@ func TestForWithMutation(t *testing.T) { s.Properties["D"].MinItems = jsonschema.Ptr(10) s.Properties["E"].Types[0] = "mutated" - s2, err := jsonschema.For[T]() + s2, err := jsonschema.For[T](nil) if err != nil { t.Fatalf("For: %v", err) } @@ -266,13 +278,13 @@ func TestForWithCycle(t *testing.T) { shouldErr bool fn func() error }{ - {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }}, - {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }}, - {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }}, - {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }}, - {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }}, - {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }}, - {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }}, + {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](nil); return err }}, + {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](nil); return err }}, + {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](nil); return err }}, + {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](nil); return err }}, + {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](nil); return err }}, + {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](nil); return err }}, + {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](nil); return err }}, } for _, test := range tests { diff --git a/jsonschema/schema.go b/jsonschema/schema.go index 4b1d6eed..1d60de12 100644 --- a/jsonschema/schema.go +++ b/jsonschema/schema.go @@ -152,6 +152,42 @@ func (s *Schema) String() string { return "" } +// CloneSchemas returns a copy of s. +// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas. +// This allows both s and s.CloneSchemas() to appear as sub-schemas in the same parent. +func (s *Schema) CloneSchemas() *Schema { + if s == nil { + return nil + } + s2 := *s + v := reflect.ValueOf(&s2) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + sscss := fv.Interface().(*Schema) + fv.Set(reflect.ValueOf(sscss.CloneSchemas())) + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + slice = slices.Clone(slice) + for i, ss := range slice { + slice[i] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(slice)) + + case schemaMapType: + m := fv.Interface().(map[string]*Schema) + m = maps.Clone(m) + for k, ss := range m { + m[k] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(m)) + } + } + return &s2 +} + func (s *Schema) basicChecks() error { if s.Type != "" && s.Types != nil { return errors.New("both Type and Types are set; at most one should be") diff --git a/jsonschema/schema_test.go b/jsonschema/schema_test.go index 4ceb1ee1..4b6df511 100644 --- a/jsonschema/schema_test.go +++ b/jsonschema/schema_test.go @@ -142,3 +142,34 @@ func (s *Schema) jsonIndent() string { } return string(data) } + +func TestCloneSchemas(t *testing.T) { + ss1 := &Schema{Type: "string"} + ss2 := &Schema{Type: "integer"} + ss3 := &Schema{Type: "boolean"} + ss4 := &Schema{Type: "number"} + ss5 := &Schema{Contains: ss4} + + s1 := Schema{ + Contains: ss1, + PrefixItems: []*Schema{ss2, ss3}, + Properties: map[string]*Schema{"a": ss5}, + } + s2 := s1.CloneSchemas() + + // The clones should appear identical. + if g, w := s1.json(), s2.json(); g != w { + t.Errorf("\ngot %s\nwant %s", g, w) + } + // None of the schemas should overlap. + schemas1 := map[*Schema]bool{ss1: true, ss2: true, ss3: true, ss4: true, ss5: true} + for ss := range s2.all() { + if schemas1[ss] { + t.Errorf("uncloned schema %s", ss.json()) + } + } + // s1's original schemas should be intact. + if s1.Contains != ss1 || s1.PrefixItems[0] != ss2 || s1.PrefixItems[1] != ss3 || ss5.Contains != ss4 || s1.Properties["a"] != ss5 { + t.Errorf("s1 modified") + } +} diff --git a/mcp/tool.go b/mcp/tool.go index fc154991..ed80b660 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -89,7 +89,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error { var err error if *sfield == nil { - *sfield, err = jsonschema.For[T]() + *sfield, err = jsonschema.For[T](nil) } if err != nil { return err From cb17593a5dc04ba661528e9c42988641980dc94d Mon Sep 17 00:00:00 2001 From: Jim Clark Date: Mon, 4 Aug 2025 11:46:49 -0700 Subject: [PATCH 063/221] update Subscribe/Unsubscribe handlers in ServerOpts (#231) ## Background Since the Server has a map of ServersSessions->subscriptions, and takes care of routing ResourceUpdate messages to ServerSessions, I think the decision was made to leave *ServerSession out of the SubscribeHandler/UnsubscribeHandler signatures. However, resource subscriptions are not specific to the a particular client and this information is kept private in the Session today. ## Proposal Align the Subscribe/Unsubscribe handlers with the other two handlers that are client specific (ProgressNotificationHandler, and RootListChangeHanlder), and include *ServerSession in the signature. It comes in practice because because a gateway server that needs to forward subscriptions still needs this information. The routing logic in Server is convenient and means that most servers won't need to know which ServerSessions are subscribed. But I think there are still use cases where the user will need to know. --- design/design.md | 4 ++-- mcp/mcp_test.go | 4 ++-- mcp/server.go | 8 ++++---- mcp/server_test.go | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/design/design.md b/design/design.md index bfabeac7..33dc3a63 100644 --- a/design/design.md +++ b/design/design.md @@ -776,9 +776,9 @@ If a server author wants to support resource subscriptions, they must provide ha type ServerOptions struct { ... // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *SubscribeParams) error + SubscribeHandler func(context.Context, ss *ServerSession, *SubscribeParams) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *UnsubscribeParams) error + UnsubscribeHandler func(context.Context, ss *ServerSession, *UnsubscribeParams) error } ``` diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 819edeb6..da53465c 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -78,11 +78,11 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { notificationChans["progress_server"] <- 0 }, - SubscribeHandler: func(context.Context, *SubscribeParams) error { + SubscribeHandler: func(context.Context, *ServerSession, *SubscribeParams) error { notificationChans["subscribe"] <- 0 return nil }, - UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error { + UnsubscribeHandler: func(context.Context, *ServerSession, *UnsubscribeParams) error { notificationChans["unsubscribe"] <- 0 return nil }, diff --git a/mcp/server.go b/mcp/server.go index ba2bd0a9..c8878da3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -66,9 +66,9 @@ type ServerOptions struct { // the session is automatically closed. KeepAlive time.Duration // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *SubscribeParams) error + SubscribeHandler func(context.Context, *ServerSession, *SubscribeParams) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *UnsubscribeParams) error + UnsubscribeHandler func(context.Context, *ServerSession, *UnsubscribeParams) error // If true, advertises the prompts capability during initialization, // even if no prompts have been registered. HasPrompts bool @@ -469,7 +469,7 @@ func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *Subsc if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } - if err := s.opts.SubscribeHandler(ctx, params); err != nil { + if err := s.opts.SubscribeHandler(ctx, ss, params); err != nil { return nil, err } @@ -488,7 +488,7 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns return nil, jsonrpc2.ErrMethodNotFound } - if err := s.opts.UnsubscribeHandler(ctx, params); err != nil { + if err := s.opts.UnsubscribeHandler(ctx, ss, params); err != nil { return nil, err } diff --git a/mcp/server_test.go b/mcp/server_test.go index 0b853a33..5a161b72 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -281,10 +281,10 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + SubscribeHandler: func(ctx context.Context, _ *ServerSession, sp *SubscribeParams) error { return nil }, - UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + UnsubscribeHandler: func(ctx context.Context, _ *ServerSession, up *UnsubscribeParams) error { return nil }, }, @@ -325,10 +325,10 @@ func TestServerCapabilities(t *testing.T) { s.AddTool(tool, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + SubscribeHandler: func(ctx context.Context, _ *ServerSession, sp *SubscribeParams) error { return nil }, - UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + UnsubscribeHandler: func(ctx context.Context, _ *ServerSession, up *UnsubscribeParams) error { return nil }, CompletionHandler: func(ctx context.Context, ss *ServerSession, params *CompleteParams) (*CompleteResult, error) { From 90a0b1fb2378c3de445d78bf4859f1fbb850f6b1 Mon Sep 17 00:00:00 2001 From: cryo Date: Tue, 5 Aug 2025 22:46:39 +0800 Subject: [PATCH 064/221] mcp: fix comments (#241) Update inaccurate doc comment for ClientSession. --- mcp/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 9e3f3935..7c68870c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -153,8 +153,8 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e // methods can be used to send requests or notifications to the server. Create // a session by calling [Client.Connect]. // -// Call [ClientSession.Close] to close the connection, or await client -// termination with [ServerSession.Wait]. +// Call [ClientSession.Close] to close the connection, or await server +// termination with [ClientSession.Wait]. type ClientSession struct { conn *jsonrpc2.Connection client *Client From 31c18fba2da33d116a9a546b6d02b19d0ff1d089 Mon Sep 17 00:00:00 2001 From: winterfx <136159170+winterfx@users.noreply.github.com> Date: Wed, 6 Aug 2025 00:05:51 +0800 Subject: [PATCH 065/221] docs: fix middleware example function signature (#236) - Correct MethodHandler type parameter to use fully qualified mcp.ServerSession - Update parameter types to use proper mcp.Params and mcp.Result instead of any - Ensure example code matches the actual API --- design/design.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/design/design.md b/design/design.md index 33dc3a63..93dc5521 100644 --- a/design/design.md +++ b/design/design.md @@ -456,8 +456,8 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession] As an example, this code adds server-side logging: ```go -func withLogging(h mcp.MethodHandler[*ServerSession]) mcp.MethodHandler[*ServerSession]{ - return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { +func withLogging(h mcp.MethodHandler[*mcp.ServerSession]) mcp.MethodHandler[*mcp.ServerSession]{ + return func(ctx context.Context, s *mcp.ServerSession, method string, params mcp.Params) (res mcp.Result, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() return h(ctx, s , method, params) From 272a5ac76cd4eeb583f228323c482156d0bf9116 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 5 Aug 2025 13:42:22 +0000 Subject: [PATCH 066/221] .github: add race and 1.25 tests This CL makes the following changes to our Github workflows, motivated by #227: - Add -race tests. - Add go1.25rc2 to the build matrix. - Check out code before setting up Go, so that the setup step can use dependency caching. --- .github/workflows/test.yml | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2f8cb8e..d5009944 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,11 +1,11 @@ name: Test on: # Manual trigger - workflow_dispatch: + workflow_dispatch: push: branches: main pull_request: - + permissions: contents: read @@ -13,10 +13,10 @@ jobs: lint: runs-on: ubuntu-latest steps: - - name: Set up Go - uses: actions/setup-go@v5 - name: Check out code uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 - name: Check formatting run: | unformatted=$(gofmt -l .) @@ -26,18 +26,30 @@ jobs: exit 1 fi echo "All Go files are properly formatted" - + test: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.23', '1.24' ] + go: [ '1.23', '1.24', '1.25.0-rc.2' ] steps: + - name: Check out code + uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - - name: Check out code - uses: actions/checkout@v4 - name: Test run: go test -v ./... + + race-test: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.24' + - name: Test with -race + run: go test -v -race ./... From f4a9396942843dbbe20cb3d4d1de962e720ce887 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 5 Aug 2025 14:50:21 +0000 Subject: [PATCH 067/221] mcp: fix a bug causing premature termination of streams As reported in #227, there was a race where streams could be terminated before the final reply was written. Fix this by checking the number of outstanding requests in the same critical section that acquires outgoing messages to write. Fixes #227 --- mcp/streamable.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 3f53a689..86af05ce 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -267,9 +267,7 @@ type stream struct { // streamRequests is the set of unanswered incoming RPCs for the stream. // - // Lifecycle: requests values persist until the requests have been - // replied to by the server. Notably, NOT until they are sent to an HTTP - // response, as delivery is not guaranteed. + // Requests persist until their response data has been added to outgoing. requests map[jsonrpc.ID]struct{} } @@ -482,6 +480,7 @@ stream: t.mu.Lock() outgoing := stream.outgoing stream.outgoing = nil + nOutstanding := len(stream.requests) t.mu.Unlock() for _, data := range outgoing { @@ -493,9 +492,6 @@ stream: } } - t.mu.Lock() - nOutstanding := len(stream.requests) - t.mu.Unlock() // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server From be0bd77f48d642a1c9ff6f2b43652f959f39c57b Mon Sep 17 00:00:00 2001 From: davidleitw Date: Wed, 6 Aug 2025 00:28:36 +0800 Subject: [PATCH 068/221] mcp: test nil params crash vulnerability Test that explicit nulls don't cause panics. --- mcp/shared_test.go | 144 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/mcp/shared_test.go b/mcp/shared_test.go index f319d80e..0aea1947 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -7,6 +7,7 @@ package mcp import ( "context" "encoding/json" + "fmt" "strings" "testing" ) @@ -88,3 +89,146 @@ func TestToolValidate(t *testing.T) { }) } } + +// TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams. +// This addresses a vulnerability where missing or null parameters could crash the server. +func TestNilParamsHandling(t *testing.T) { + // Define test types for clarity + type TestArgs struct { + Name string `json:"name"` + Value int `json:"value"` + } + type TestParams = *CallToolParamsFor[TestArgs] + type TestResult = *CallToolResultFor[string] + + // Simple test handler + testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (TestResult, error) { + result := "processed: " + params.Arguments.Name + return &CallToolResultFor[string]{StructuredContent: result}, nil + } + + methodInfo := newMethodInfo(testHandler, missingParamsOK) + + // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully + mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { + t.Helper() + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unmarshalParams panicked: %v", r) + } + }() + + params, err := methodInfo.unmarshalParams(rawMsg) + if err != nil { + t.Fatalf("unmarshalParams failed: %v", err) + } + + if expectNil { + if params != nil { + t.Fatalf("Expected nil params, got %v", params) + } + return params + } + + if params == nil { + t.Fatal("unmarshalParams returned unexpected nil") + } + + // Verify the result can be used safely + typedParams := params.(TestParams) + _ = typedParams.Name + _ = typedParams.Arguments.Name + _ = typedParams.Arguments.Value + + return params + } + + // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil + t.Run("missing_params", func(t *testing.T) { + mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag + }) + + t.Run("explicit_null", func(t *testing.T) { + mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag + }) + + t.Run("empty_object", func(t *testing.T) { + mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params + }) + + t.Run("valid_params", func(t *testing.T) { + rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) + params := mustNotPanic(t, rawMsg, false) + + // For valid params, also verify the values are parsed correctly + typedParams := params.(TestParams) + if typedParams.Name != "test" { + t.Errorf("Expected name 'test', got %q", typedParams.Name) + } + if typedParams.Arguments.Name != "hello" { + t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) + } + if typedParams.Arguments.Value != 42 { + t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) + } + }) +} + +// TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix +func TestNilParamsEdgeCases(t *testing.T) { + type TestArgs struct { + Name string `json:"name"` + Value int `json:"value"` + } + type TestParams = *CallToolParamsFor[TestArgs] + + testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (*CallToolResultFor[string], error) { + return &CallToolResultFor[string]{StructuredContent: "test"}, nil + } + + methodInfo := newMethodInfo(testHandler, missingParamsOK) + + // These should fail normally, not be treated as nil params + invalidCases := []json.RawMessage{ + json.RawMessage(""), // empty string - should error + json.RawMessage("[]"), // array - should error + json.RawMessage(`"null"`), // string "null" - should error + json.RawMessage("0"), // number - should error + json.RawMessage("false"), // boolean - should error + } + + for i, rawMsg := range invalidCases { + t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { + params, err := methodInfo.unmarshalParams(rawMsg) + if err == nil && params == nil { + t.Error("Should not return nil params without error") + } + }) + } + + // Test that methods without missingParamsOK flag properly reject nil params + t.Run("reject_when_params_required", func(t *testing.T) { + methodInfoStrict := newMethodInfo(testHandler, 0) // No missingParamsOK flag + + testCases := []struct { + name string + params json.RawMessage + }{ + {"nil_params", nil}, + {"null_params", json.RawMessage(`null`)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := methodInfoStrict.unmarshalParams(tc.params) + if err == nil { + t.Error("Expected error for required params, got nil") + } + if !strings.Contains(err.Error(), "missing required \"params\"") { + t.Errorf("Expected 'missing required params' error, got: %v", err) + } + }) + } + }) +} From 032f03bbeafa5702d66ebd84edf138fc2c74429c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 5 Aug 2025 16:32:06 -0400 Subject: [PATCH 069/221] jsonschema: add ForType function (#242) Add ForType, which is like For but takes a reflect.Type. Fixes #233. --- jsonschema/infer.go | 41 +++++++++++++++++++++++++++++----------- jsonschema/infer_test.go | 37 +++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 7b6b7e2b..080d9d09 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -9,6 +9,7 @@ package jsonschema import ( "fmt" "log/slog" + "maps" "math/big" "reflect" "regexp" @@ -19,8 +20,8 @@ import ( // ForOptions are options for the [For] function. type ForOptions struct { - // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON Schema - // are ignored instead of causing an error. + // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON + // Schema are ignored instead of causing an error. // This allows callers to adjust the resulting schema using custom knowledge. // For example, an interface type where all the possible implementations are // known can be described with "oneof". @@ -77,15 +78,7 @@ func For[T any](opts *ForOptions) (*Schema, error) { if opts == nil { opts = &ForOptions{} } - schemas := make(map[reflect.Type]*Schema) - // Add types from the standard library that have MarshalJSON methods. - ss := &Schema{Type: "string"} - schemas[reflect.TypeFor[time.Time]()] = ss - schemas[reflect.TypeFor[slog.Level]()] = ss - schemas[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} - schemas[reflect.TypeFor[big.Rat]()] = ss - schemas[reflect.TypeFor[big.Float]()] = ss - + schemas := maps.Clone(initialSchemaMap) // Add types from the options. They override the default ones. for v, s := range opts.TypeSchemas { schemas[reflect.TypeOf(v)] = s @@ -98,6 +91,20 @@ func For[T any](opts *ForOptions) (*Schema, error) { return s, nil } +// ForType is like [For], but takes a [reflect.Type] +func ForType(t reflect.Type, opts *ForOptions) (*Schema, error) { + schemas := maps.Clone(initialSchemaMap) + // Add types from the options. They override the default ones. + for v, s := range opts.TypeSchemas { + schemas[reflect.TypeOf(v)] = s + } + s, err := forType(t, map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) + if err != nil { + return nil, fmt.Errorf("ForType(%s): %w", t, err) + } + return s, nil +} + func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. @@ -230,5 +237,17 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas ma return s, nil } +// initialSchemaMap holds types from the standard library that have MarshalJSON methods. +var initialSchemaMap = make(map[reflect.Type]*Schema) + +func init() { + ss := &Schema{Type: "string"} + initialSchemaMap[reflect.TypeFor[time.Time]()] = ss + initialSchemaMap[reflect.TypeFor[slog.Level]()] = ss + initialSchemaMap[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} + initialSchemaMap[reflect.TypeFor[big.Rat]()] = ss + initialSchemaMap[reflect.TypeFor[big.Float]()] = ss +} + // Disallow jsonschema tag values beginning "WORD=", for future expansion. var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 1a0895b4..62bfbbbc 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -7,6 +7,7 @@ package jsonschema_test import ( "log/slog" "math/big" + "reflect" "strings" "testing" "time" @@ -139,7 +140,7 @@ func TestFor(t *testing.T) { run := func(t *testing.T, tt test) { if diff := cmp.Diff(tt.want, tt.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ForType mismatch (-want +got):\n%s", diff) + t.Fatalf("For mismatch (-want +got):\n%s", diff) } // These schemas should all resolve. if _, err := tt.got.Resolve(nil); err != nil { @@ -176,6 +177,40 @@ func TestFor(t *testing.T) { }) } +func TestForType(t *testing.T) { + type schema = jsonschema.Schema + + // ForType is virtually identical to For. Just test that options are handled properly. + opts := &jsonschema.ForOptions{ + IgnoreInvalidTypes: true, + TypeSchemas: map[any]*jsonschema.Schema{ + custom(0): {Type: "custom"}, + }, + } + + type S struct { + I int + F func() + C custom + } + got, err := jsonschema.ForType(reflect.TypeOf(S{}), opts) + if err != nil { + t.Fatal(err) + } + want := &schema{ + Type: "object", + Properties: map[string]*schema{ + "I": {Type: "integer"}, + "C": {Type: "custom"}, + }, + Required: []string{"I", "C"}, + AdditionalProperties: falseSchema(), + } + if diff := cmp.Diff(want, got, cmpopts.IgnoreUnexported(schema{})); diff != "" { + t.Fatalf("ForType mismatch (-want +got):\n%s", diff) + } +} + func forErr[T any]() error { _, err := jsonschema.For[T](nil) return err From a02c2ffd0e7fa8d1417dd4b85cd4c0890b15238b Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 4 Aug 2025 23:20:55 +0000 Subject: [PATCH 070/221] mcp: unblock reads of ioConn as soon as Close is called Use a read loop and incoming channel so that we can select on incoming messages and unblock calls of ioConn.Read as soon as Close is called. This avoids the scenario of #224, where a close of the StdioTransport does not gracefully shut down the JSON-RPC connection. Fixes #224 --- mcp/cmd_test.go | 95 +++++++++++++++++++++++++++++++++++++++++------- mcp/transport.go | 71 +++++++++++++++++++++++++++++++----- 2 files changed, 142 insertions(+), 24 deletions(-) diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index bc149f4c..82a35a80 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -10,7 +10,9 @@ import ( "log" "os" "os/exec" + "os/signal" "runtime" + "syscall" "testing" "time" @@ -21,14 +23,27 @@ import ( const runAsServer = "_MCP_RUN_AS_SERVER" func TestMain(m *testing.M) { - if os.Getenv(runAsServer) != "" { + // If the runAsServer variable is set, execute the relevant serverFunc + // instead of running tests (aka the fork and exec trick). + if name := os.Getenv(runAsServer); name != "" { + run := serverFuncs[name] + if run == nil { + log.Fatalf("Unknown server %q", name) + } os.Unsetenv(runAsServer) - runServer() + run() return } os.Exit(m.Run()) } +// serverFuncs defines server functions that may be run as subprocesses via +// [TestMain]. +var serverFuncs = map[string]func(){ + "default": runServer, + "cancelContext": runCancelContextServer, +} + func runServer() { ctx := context.Background() @@ -39,6 +54,16 @@ func runServer() { } } +func runCancelContextServer() { + ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT) + defer done() + + server := mcp.NewServer(testImpl, nil) + if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { + log.Fatal(err) + } +} + func TestServerRunContextCancel(t *testing.T) { server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) @@ -80,15 +105,18 @@ func TestServerRunContextCancel(t *testing.T) { } func TestServerInterrupt(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } requireExec(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cmd := createServerCommand(t) + cmd := createServerCommand(t, "default") client := mcp.NewClient(testImpl, nil) - session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + _, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { t.Fatal(err) } @@ -101,19 +129,54 @@ func TestServerInterrupt(t *testing.T) { }() // send a signal to the server process to terminate it + cmd.Process.Signal(os.Interrupt) + + // wait for the server to exit + // TODO: use synctest when available + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after SIGINT") + case <-onExit: + } +} + +func TestStdioContextCancellation(t *testing.T) { if runtime.GOOS == "windows" { - // Windows does not support os.Interrupt - session.Close() - } else { - cmd.Process.Signal(os.Interrupt) + t.Skip("requires POSIX signals") } + requireExec(t) + + // This test is a variant of TestServerInterrupt reproducing the conditions + // of #224, where interrupt failed to shut down the server because reads of + // Stdin were not unblocked. + + cmd := createServerCommand(t, "cancelContext") + // Creating a stdin pipe causes os.Stdin.Close to not immediately unblock + // pending reads. + _, _ = cmd.StdinPipe() + + // Just Start the command, rather than connecting to the server, because we + // don't want the client connection to indirectly flush stdin through writes. + if err := cmd.Start(); err != nil { + t.Fatalf("starting command: %v", err) + } + + // Sleep to make it more likely that the server is blocked in the read loop. + time.Sleep(100 * time.Millisecond) + + onExit := make(chan struct{}) + go func() { + cmd.Process.Wait() + close(onExit) + }() + + cmd.Process.Signal(os.Interrupt) - // wait for the server to exit - // TODO: use synctest when availble select { case <-time.After(5 * time.Second): - t.Fatal("server did not exit after SIGTERM") + t.Fatal("server did not exit after SIGINT") case <-onExit: + t.Logf("done.") } } @@ -123,7 +186,7 @@ func TestCmdTransport(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cmd := createServerCommand(t) + cmd := createServerCommand(t, "default") client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) @@ -150,7 +213,11 @@ func TestCmdTransport(t *testing.T) { } } -func createServerCommand(t *testing.T) *exec.Cmd { +// createServerCommand creates a command to fork and exec the test binary as an +// MCP server. +// +// serverName must refer to an entry in the [serverFuncs] map. +func createServerCommand(t *testing.T, serverName string) *exec.Cmd { t.Helper() exe, err := os.Executable() @@ -158,7 +225,7 @@ func createServerCommand(t *testing.T) *exec.Cmd { t.Fatal(err) } cmd := exec.Command(exe) - cmd.Env = append(os.Environ(), runAsServer+"=true") + cmd.Env = append(os.Environ(), runAsServer+"="+serverName) return cmd } diff --git a/mcp/transport.go b/mcp/transport.go index 5175f6f0..dccc920b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -265,7 +265,9 @@ func (r rwc) Close() error { // See [msgBatch] for more discussion of message batching. type ioConn struct { rwc io.ReadWriteCloser // the underlying stream - in *json.Decoder // a decoder bound to rwc + + // incoming receives messages from the read loop started in [newIOConn]. + incoming <-chan msgOrErr // If outgoiBatch has a positive capacity, it will be used to batch requests // and notifications before sending. @@ -279,12 +281,47 @@ type ioConn struct { // Since writes may be concurrent to reads, we need to guard this with a mutex. batchMu sync.Mutex batches map[jsonrpc2.ID]*msgBatch // lazily allocated + + closeOnce sync.Once + closed chan struct{} + closeErr error +} + +type msgOrErr struct { + msg json.RawMessage + err error } func newIOConn(rwc io.ReadWriteCloser) *ioConn { + var ( + incoming = make(chan msgOrErr) + closed = make(chan struct{}) + ) + // Start a goroutine for reads, so that we can select on the incoming channel + // in [ioConn.Read] and unblock the read as soon as Close is called (see #224). + // + // This leaks a goroutine, but that is unavoidable since AFAIK there is no + // (easy and portable) way to guarantee that reads of stdin are unblocked + // when closed. + go func() { + dec := json.NewDecoder(rwc) + for { + var raw json.RawMessage + err := dec.Decode(&raw) + select { + case incoming <- msgOrErr{msg: raw, err: err}: + case <-closed: + return + } + if err != nil { + return + } + } + }() return &ioConn{ - rwc: rwc, - in: json.NewDecoder(rwc), + rwc: rwc, + incoming: incoming, + closed: closed, } } @@ -356,10 +393,8 @@ type msgBatch struct { } func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { - return t.read(ctx, t.in) -} - -func (t *ioConn) read(ctx context.Context, in *json.Decoder) (jsonrpc.Message, error) { + // As a matter of principle, enforce that reads on a closed context return an + // error. select { case <-ctx.Done(): return nil, ctx.Err() @@ -372,9 +407,20 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (jsonrpc.Message, e } var raw json.RawMessage - if err := in.Decode(&raw); err != nil { - return nil, err + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case v := <-t.incoming: + if v.err != nil { + return nil, v.err + } + raw = v.msg + + case <-t.closed: + return nil, io.EOF } + msgs, batch, err := readBatch(raw) if err != nil { return nil, err @@ -431,6 +477,7 @@ func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { } func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // As in [ioConn.Read], enforce that Writes on a closed context are an error. select { case <-ctx.Done(): return ctx.Err() @@ -478,7 +525,11 @@ func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { } func (t *ioConn) Close() error { - return t.rwc.Close() + t.closeOnce.Do(func() { + t.closeErr = t.rwc.Close() + close(t.closed) + }) + return t.closeErr } func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { From bb9087d2526b216b902095abbf64a5ab728ef8a4 Mon Sep 17 00:00:00 2001 From: Rishabh Nrupnarayan Date: Wed, 6 Aug 2025 21:10:39 +0530 Subject: [PATCH 071/221] mcp: handle bad trailing stdio input graciously (#179) (#192) You can make MCP server log error, send response and then abruptly exit, with a json input malformed at the end. While using json.RawMessage, once a valid first json is decoded, trailing bad input is silently ignored. Without graciously handling this input, mcp server is currently sending response as well as encountering error. It should just report error without further processing of request. Fixes #179. --- mcp/transport.go | 13 +++++++++++++ mcp/transport_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/mcp/transport.go b/mcp/transport.go index dccc920b..f2d5c72d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -308,6 +308,19 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { for { var raw json.RawMessage err := dec.Decode(&raw) + // If decoding was successful, check for trailing data at the end of the stream. + if err == nil { + // Read the next byte to check if there is trailing data. + var tr [1]byte + if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { + // If read byte is not a newline, it is an error. + if tr[0] != '\n' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } select { case incoming <- msgOrErr{msg: raw, err: err}: case <-closed: diff --git a/mcp/transport_test.go b/mcp/transport_test.go index c63b84ee..18a326e8 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -7,6 +7,7 @@ package mcp import ( "context" "io" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -51,3 +52,41 @@ func TestBatchFraming(t *testing.T) { } } } + +func TestIOConnRead(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + + { + name: "valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`, + want: "", + }, + + { + name: "newline at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}} + `, + want: "", + }, + { + name: "bad data at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`, + want: "invalid trailing data at the end of stream", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := newIOConn(rwc{ + rc: io.NopCloser(strings.NewReader(tt.input)), + }) + _, err := tr.Read(context.Background()) + if err != nil && err.Error() != tt.want { + t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) + } + }) + } +} From c132621eb047dab7897b00e85769be42c9921bf5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 6 Aug 2025 11:41:02 -0400 Subject: [PATCH 072/221] jsonschema: marshal to bools (#247) Marshal the empty schema to true, and its negation to false. This is technically a breaking behavior change, but a properly written consumer of JSON Schema will not notice. For #244. Fixes #230. --- jsonschema/schema.go | 15 ++++++++++++++- jsonschema/schema_test.go | 6 +++--- mcp/testdata/conformance/server/tools.txtar | 4 +--- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/jsonschema/schema.go b/jsonschema/schema.go index 1d60de12..0b4764d6 100644 --- a/jsonschema/schema.go +++ b/jsonschema/schema.go @@ -204,6 +204,7 @@ func (s *Schema) MarshalJSON() ([]byte, error) { if err := s.basicChecks(); err != nil { return nil, err } + // Marshal either Type or Types as "type". var typ any switch { @@ -219,7 +220,19 @@ func (s *Schema) MarshalJSON() ([]byte, error) { Type: typ, schemaWithoutMethods: (*schemaWithoutMethods)(s), } - return marshalStructWithMap(&ms, "Extra") + bs, err := marshalStructWithMap(&ms, "Extra") + if err != nil { + return nil, err + } + // Marshal {} as true and {"not": {}} as false. + // It is wasteful to do this here instead of earlier, but much easier. + switch { + case bytes.Equal(bs, []byte(`{}`)): + bs = []byte("true") + case bytes.Equal(bs, []byte(`{"not":true}`)): + bs = []byte("false") + } + return bs, nil } func (s *Schema) UnmarshalJSON(data []byte) error { diff --git a/jsonschema/schema_test.go b/jsonschema/schema_test.go index 4b6df511..19f6c6c7 100644 --- a/jsonschema/schema_test.go +++ b/jsonschema/schema_test.go @@ -54,9 +54,9 @@ func TestJSONRoundTrip(t *testing.T) { for _, tt := range []struct { in, want string }{ - {`true`, `{}`}, // boolean schemas become object schemas - {`false`, `{"not":{}}`}, - {`{"type":"", "enum":null}`, `{}`}, // empty fields are omitted + {`true`, `true`}, + {`false`, `false`}, + {`{"type":"", "enum":null}`, `true`}, // empty fields are omitted {`{"minimum":1}`, `{"minimum":1}`}, {`{"minimum":1.0}`, `{"minimum":1}`}, // floating-point integers lose their fractional part {`{"minLength":1.0}`, `{"minLength":1}`}, // some floats are unmarshaled into ints, but you can't tell diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index b4068d1c..01fdb266 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -59,9 +59,7 @@ greet "type": "string" } }, - "additionalProperties": { - "not": {} - } + "additionalProperties": false }, "name": "greet" } From e00c4859698408ae7e13d4d6d5e6d5df606ba558 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 30 Jul 2025 19:33:15 +0000 Subject: [PATCH 073/221] examples: add a 'listeverything' client example, and reorganize Add a client example that lists all features on a stdio server. Since this is our first client example, organize examples by client / server -- otherwise it is too difficult to find client examples. For #33 --- examples/client/listfeatures/main.go | 65 +++++++++++++++++++ examples/{ => server}/completion/main.go | 0 .../{ => server}/custom-transport/main.go | 0 examples/{ => server}/hello/main.go | 0 examples/{ => server}/memory/kb.go | 0 examples/{ => server}/memory/kb_test.go | 0 examples/{ => server}/memory/main.go | 0 examples/{ => server}/rate-limiting/go.mod | 0 examples/{ => server}/rate-limiting/go.sum | 0 examples/{ => server}/rate-limiting/main.go | 0 .../{ => server}/sequentialthinking/README.md | 0 .../{ => server}/sequentialthinking/main.go | 0 .../sequentialthinking/main_test.go | 0 examples/{ => server}/sse/main.go | 0 14 files changed, 65 insertions(+) create mode 100644 examples/client/listfeatures/main.go rename examples/{ => server}/completion/main.go (100%) rename examples/{ => server}/custom-transport/main.go (100%) rename examples/{ => server}/hello/main.go (100%) rename examples/{ => server}/memory/kb.go (100%) rename examples/{ => server}/memory/kb_test.go (100%) rename examples/{ => server}/memory/main.go (100%) rename examples/{ => server}/rate-limiting/go.mod (100%) rename examples/{ => server}/rate-limiting/go.sum (100%) rename examples/{ => server}/rate-limiting/main.go (100%) rename examples/{ => server}/sequentialthinking/README.md (100%) rename examples/{ => server}/sequentialthinking/main.go (100%) rename examples/{ => server}/sequentialthinking/main_test.go (100%) rename examples/{ => server}/sse/main.go (100%) diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go new file mode 100644 index 00000000..caf21bfe --- /dev/null +++ b/examples/client/listfeatures/main.go @@ -0,0 +1,65 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The listfeatures command lists all features of a stdio MCP server. +// +// Usage: listfeatures [] +// +// For example: +// +// listfeatures go run github.com/modelcontextprotocol/go-sdk/examples/server/hello +// +// or +// +// listfeatures npx @modelcontextprotocol/server-everything +package main + +import ( + "context" + "flag" + "fmt" + "iter" + "log" + "os" + "os/exec" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func main() { + flag.Parse() + args := flag.Args() + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "Usage: listfeatures []") + fmt.Fprintf(os.Stderr, "List all features for a stdio MCP server") + fmt.Fprintln(os.Stderr) + fmt.Fprintf(os.Stderr, "Example: listfeatures npx @modelcontextprotocol/server-everything") + os.Exit(2) + } + + ctx := context.Background() + cmd := exec.Command(args[0], args[1:]...) + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name }) + printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name }) + printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name }) + printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name }) +} + +func printSection[T any](name string, features iter.Seq2[T, error], featName func(T) string) { + fmt.Printf("%s:\n", name) + for feat, err := range features { + if err != nil { + log.Fatal(err) + } + fmt.Printf("\t%s\n", featName(feat)) + } + fmt.Println() +} diff --git a/examples/completion/main.go b/examples/server/completion/main.go similarity index 100% rename from examples/completion/main.go rename to examples/server/completion/main.go diff --git a/examples/custom-transport/main.go b/examples/server/custom-transport/main.go similarity index 100% rename from examples/custom-transport/main.go rename to examples/server/custom-transport/main.go diff --git a/examples/hello/main.go b/examples/server/hello/main.go similarity index 100% rename from examples/hello/main.go rename to examples/server/hello/main.go diff --git a/examples/memory/kb.go b/examples/server/memory/kb.go similarity index 100% rename from examples/memory/kb.go rename to examples/server/memory/kb.go diff --git a/examples/memory/kb_test.go b/examples/server/memory/kb_test.go similarity index 100% rename from examples/memory/kb_test.go rename to examples/server/memory/kb_test.go diff --git a/examples/memory/main.go b/examples/server/memory/main.go similarity index 100% rename from examples/memory/main.go rename to examples/server/memory/main.go diff --git a/examples/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod similarity index 100% rename from examples/rate-limiting/go.mod rename to examples/server/rate-limiting/go.mod diff --git a/examples/rate-limiting/go.sum b/examples/server/rate-limiting/go.sum similarity index 100% rename from examples/rate-limiting/go.sum rename to examples/server/rate-limiting/go.sum diff --git a/examples/rate-limiting/main.go b/examples/server/rate-limiting/main.go similarity index 100% rename from examples/rate-limiting/main.go rename to examples/server/rate-limiting/main.go diff --git a/examples/sequentialthinking/README.md b/examples/server/sequentialthinking/README.md similarity index 100% rename from examples/sequentialthinking/README.md rename to examples/server/sequentialthinking/README.md diff --git a/examples/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go similarity index 100% rename from examples/sequentialthinking/main.go rename to examples/server/sequentialthinking/main.go diff --git a/examples/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go similarity index 100% rename from examples/sequentialthinking/main_test.go rename to examples/server/sequentialthinking/main_test.go diff --git a/examples/sse/main.go b/examples/server/sse/main.go similarity index 100% rename from examples/sse/main.go rename to examples/server/sse/main.go From b392875dc5aeb6da971fcf3da4b15179dd25d71c Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Wed, 6 Aug 2025 12:07:57 -0400 Subject: [PATCH 074/221] internal/util: move FieldJSONInfo into jsonschema/util (#251) This CL effectively removes any dependencies on mcp from jsonschema. This will enable us to move the repo. This CL also deletes unused marshalStructWithMap and unmarshalStructWithMap which had a dependency on FieldJSONInfo. Fixes: #250 --- internal/util/util.go | 39 ------------- internal/util/util_test.go | 38 ------------ jsonschema/infer.go | 12 ++-- jsonschema/schema.go | 8 +-- jsonschema/util.go | 46 +++++++++++++-- jsonschema/util_test.go | 28 +++++++++ jsonschema/validate.go | 10 ++-- mcp/util.go | 117 ------------------------------------- mcp/util_test.go | 48 --------------- 9 files changed, 82 insertions(+), 264 deletions(-) delete mode 100644 internal/util/util_test.go delete mode 100644 mcp/util_test.go diff --git a/internal/util/util.go b/internal/util/util.go index f8c5baf9..4b5c325f 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -8,9 +8,7 @@ import ( "cmp" "fmt" "iter" - "reflect" "slices" - "strings" ) // Helpers below are copied from gopls' moremaps package. @@ -38,43 +36,6 @@ func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { return r } -type JSONInfo struct { - Omit bool // unexported or first tag element is "-" - Name string // Go field name or first tag element. Empty if Omit is true. - Settings map[string]bool // "omitempty", "omitzero", etc. -} - -// FieldJSONInfo reports information about how encoding/json -// handles the given struct field. -// If the field is unexported, JSONInfo.Omit is true and no other JSONInfo field -// is populated. -// If the field is exported and has no tag, then Name is the field's name and all -// other fields are false. -// Otherwise, the information is obtained from the tag. -func FieldJSONInfo(f reflect.StructField) JSONInfo { - if !f.IsExported() { - return JSONInfo{Omit: true} - } - info := JSONInfo{Name: f.Name} - if tag, ok := f.Tag.Lookup("json"); ok { - name, rest, found := strings.Cut(tag, ",") - // "-" means omit, but "-," means the name is "-" - if name == "-" && !found { - return JSONInfo{Omit: true} - } - if name != "" { - info.Name = name - } - if len(rest) > 0 { - info.Settings = map[string]bool{} - for _, s := range strings.Split(rest, ",") { - info.Settings[s] = true - } - } - } - return info -} - // Wrapf wraps *errp with the given formatted message if *errp is not nil. func Wrapf(errp *error, format string, args ...any) { if *errp != nil { diff --git a/internal/util/util_test.go b/internal/util/util_test.go deleted file mode 100644 index 6a2b8676..00000000 --- a/internal/util/util_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package util - -import ( - "reflect" - "testing" -) - -func TestJSONInfo(t *testing.T) { - type S struct { - A int - B int `json:","` - C int `json:"-"` - D int `json:"-,"` - E int `json:"echo"` - F int `json:"foxtrot,omitempty"` - g int `json:"golf"` - } - want := []JSONInfo{ - {Name: "A"}, - {Name: "B"}, - {Omit: true}, - {Name: "-"}, - {Name: "echo"}, - {Name: "foxtrot", Settings: map[string]bool{"omitempty": true}}, - {Omit: true}, - } - tt := reflect.TypeFor[S]() - for i := range tt.NumField() { - got := FieldJSONInfo(tt.Field(i)) - if !reflect.DeepEqual(got, want[i]) { - t.Errorf("got %+v, want %+v", got, want[i]) - } - } -} diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 080d9d09..aa01fcd2 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -14,8 +14,6 @@ import ( "reflect" "regexp" "time" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) // ForOptions are options for the [For] function. @@ -192,8 +190,8 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas ma for i := range t.NumField() { field := t.Field(i) - info := util.FieldJSONInfo(field) - if info.Omit { + info := fieldJSONInfo(field) + if info.omit { continue } if s.Properties == nil { @@ -216,9 +214,9 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas ma } fs.Description = tag } - s.Properties[info.Name] = fs - if !info.Settings["omitempty"] && !info.Settings["omitzero"] { - s.Required = append(s.Required, info.Name) + s.Properties[info.name] = fs + if !info.settings["omitempty"] && !info.settings["omitzero"] { + s.Required = append(s.Required, info.name) } } diff --git a/jsonschema/schema.go b/jsonschema/schema.go index 0b4764d6..9a68cd5d 100644 --- a/jsonschema/schema.go +++ b/jsonschema/schema.go @@ -15,8 +15,6 @@ import ( "math" "reflect" "slices" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) // A Schema is a JSON schema object. @@ -424,9 +422,9 @@ var ( func init() { for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { - info := util.FieldJSONInfo(sf) - if !info.Omit { - schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.Name}) + info := fieldJSONInfo(sf) + if !info.omit { + schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.name}) } } slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { diff --git a/jsonschema/util.go b/jsonschema/util.go index 71c34439..f5291536 100644 --- a/jsonschema/util.go +++ b/jsonschema/util.go @@ -15,9 +15,8 @@ import ( "math/big" "reflect" "slices" + "strings" "sync" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) // Equal reports whether two Go values representing JSON values are equal according @@ -410,11 +409,48 @@ func jsonNames(t reflect.Type) map[string]bool { } continue } - info := util.FieldJSONInfo(field) - if !info.Omit { - m[info.Name] = true + info := fieldJSONInfo(field) + if !info.omit { + m[info.name] = true } } jsonNamesMap.Store(t, m) return m } + +type jsonInfo struct { + omit bool // unexported or first tag element is "-" + name string // Go field name or first tag element. Empty if omit is true. + settings map[string]bool // "omitempty", "omitzero", etc. +} + +// fieldJSONInfo reports information about how encoding/json +// handles the given struct field. +// If the field is unexported, jsonInfo.omit is true and no other jsonInfo field +// is populated. +// If the field is exported and has no tag, then name is the field's name and all +// other fields are false. +// Otherwise, the information is obtained from the tag. +func fieldJSONInfo(f reflect.StructField) jsonInfo { + if !f.IsExported() { + return jsonInfo{omit: true} + } + info := jsonInfo{name: f.Name} + if tag, ok := f.Tag.Lookup("json"); ok { + name, rest, found := strings.Cut(tag, ",") + // "-" means omit, but "-," means the name is "-" + if name == "-" && !found { + return jsonInfo{omit: true} + } + if name != "" { + info.name = name + } + if len(rest) > 0 { + info.settings = map[string]bool{} + for _, s := range strings.Split(rest, ",") { + info.settings[s] = true + } + } + } + return info +} diff --git a/jsonschema/util_test.go b/jsonschema/util_test.go index 03ccb4d7..7934bff7 100644 --- a/jsonschema/util_test.go +++ b/jsonschema/util_test.go @@ -184,3 +184,31 @@ func TestMarshalStructWithMap(t *testing.T) { } }) } + +func TestJSONInfo(t *testing.T) { + type S struct { + A int + B int `json:","` + C int `json:"-"` + D int `json:"-,"` + E int `json:"echo"` + F int `json:"foxtrot,omitempty"` + g int `json:"golf"` + } + want := []jsonInfo{ + {name: "A"}, + {name: "B"}, + {omit: true}, + {name: "-"}, + {name: "echo"}, + {name: "foxtrot", settings: map[string]bool{"omitempty": true}}, + {omit: true}, + } + tt := reflect.TypeFor[S]() + for i := range tt.NumField() { + got := fieldJSONInfo(tt.Field(i)) + if !reflect.DeepEqual(got, want[i]) { + t.Errorf("got %+v, want %+v", got, want[i]) + } + } +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 3b864107..ca4f340c 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -697,8 +697,8 @@ func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { for name, sf := range structPropertiesOf(v.Type()) { val := v.FieldByIndex(sf.Index) if val.IsZero() { - info := util.FieldJSONInfo(sf) - if info.Settings["omitempty"] || info.Settings["omitzero"] { + info := fieldJSONInfo(sf) + if info.settings["omitempty"] || info.settings["omitzero"] { continue } } @@ -750,9 +750,9 @@ func structPropertiesOf(t reflect.Type) propertyMap { } props := map[string]reflect.StructField{} for _, sf := range reflect.VisibleFields(t) { - info := util.FieldJSONInfo(sf) - if !info.Omit { - props[info.Name] = sf + info := fieldJSONInfo(sf) + if !info.omit { + props[info.name] = sf } } structProperties.Store(t, props) diff --git a/mcp/util.go b/mcp/util.go index 82d87940..102c0885 100644 --- a/mcp/util.go +++ b/mcp/util.go @@ -6,12 +6,6 @@ package mcp import ( "crypto/rand" - "encoding/json" - "fmt" - "reflect" - "sync" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) func assert(cond bool, msg string) { @@ -33,114 +27,3 @@ func randText() string { } return string(src) } - -// marshalStructWithMap marshals its first argument to JSON, treating the field named -// mapField as an embedded map. The first argument must be a pointer to -// a struct. The underlying type of mapField must be a map[string]any, and it must have -// an "omitempty" json tag. -// -// For example, given this struct: -// -// type S struct { -// A int -// Extra map[string] any `json:,omitempty` -// } -// -// and this value: -// -// s := S{A: 1, Extra: map[string]any{"B": 2}} -// -// the call marshalJSONWithMap(s, "Extra") would return -// -// {"A": 1, "B": 2} -// -// It is an error if the map contains the same key as another struct field's -// JSON name. -// -// marshalStructWithMap calls json.Marshal on a value of type T, so T must not -// have a MarshalJSON method that calls this function, on pain of infinite regress. -// -// TODO: avoid this restriction on T by forcing it to marshal in a default way. -// See https://go.dev/play/p/EgXKJHxEx_R. -func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { - // Marshal the struct and the map separately, and concatenate the bytes. - // This strategy is dramatically less complicated than - // constructing a synthetic struct or map with the combined keys. - if s == nil { - return []byte("null"), nil - } - s2 := *s - vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) - mapVal := vMapField.Interface().(map[string]any) - - // Check for duplicates. - names := jsonNames(reflect.TypeFor[T]()) - for key := range mapVal { - if names[key] { - return nil, fmt.Errorf("map key %q duplicates struct field", key) - } - } - - // Clear the map field, relying on the omitempty tag to omit it. - vMapField.Set(reflect.Zero(vMapField.Type())) - structBytes, err := json.Marshal(s2) - if err != nil { - return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) - } - if len(mapVal) == 0 { - return structBytes, nil - } - mapBytes, err := json.Marshal(mapVal) - if err != nil { - return nil, err - } - if len(structBytes) == 2 { // must be "{}" - return mapBytes, nil - } - // "{X}" + "{Y}" => "{X,Y}" - res := append(structBytes[:len(structBytes)-1], ',') - res = append(res, mapBytes[1:]...) - return res, nil -} - -// unmarshalStructWithMap is the inverse of marshalStructWithMap. -// T has the same restrictions as in that function. -func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { - // Unmarshal into the struct, ignoring unknown fields. - if err := json.Unmarshal(data, v); err != nil { - return err - } - // Unmarshal into the map. - m := map[string]any{} - if err := json.Unmarshal(data, &m); err != nil { - return err - } - // Delete from the map the fields of the struct. - for n := range jsonNames(reflect.TypeFor[T]()) { - delete(m, n) - } - if len(m) != 0 { - reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) - } - return nil -} - -var jsonNamesMap sync.Map // from reflect.Type to map[string]bool - -// jsonNames returns the set of JSON object keys that t will marshal into. -// t must be a struct type. -func jsonNames(t reflect.Type) map[string]bool { - // Lock not necessary: at worst we'll duplicate work. - if val, ok := jsonNamesMap.Load(t); ok { - return val.(map[string]bool) - } - m := map[string]bool{} - for i := range t.NumField() { - info := util.FieldJSONInfo(t.Field(i)) - if !info.Omit { - m[info.Name] = true - } - } - jsonNamesMap.Store(t, m) - return m -} diff --git a/mcp/util_test.go b/mcp/util_test.go deleted file mode 100644 index f2cb0f5c..00000000 --- a/mcp/util_test.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package mcp - -import ( - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func TestMarshalStructWithMap(t *testing.T) { - type S struct { - A int - B string `json:"b,omitempty"` - u bool - M map[string]any `json:",omitempty"` - } - t.Run("basic", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"!@#": true}} - got, err := marshalStructWithMap(&s, "M") - if err != nil { - t.Fatal(err) - } - want := `{"A":1,"b":"two","!@#":true}` - if g := string(got); g != want { - t.Errorf("\ngot %s\nwant %s", g, want) - } - - var un S - if err := unmarshalStructWithMap(got, &un, "M"); err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(s, un, cmpopts.IgnoreUnexported(S{})); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } - }) - t.Run("duplicate", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"b": "dup"}} - _, err := marshalStructWithMap(&s, "M") - if err == nil || !strings.Contains(err.Error(), "duplicate") { - t.Errorf("got %v, want error with 'duplicate'", err) - } - }) -} From 5bd02a3c0451110e8e01a56b9fcfeb048c560a92 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Wed, 6 Aug 2025 13:02:10 -0400 Subject: [PATCH 075/221] jsonschema: copy wrapf into jsonschema/util.go (#254) jsonschema depends on wrapf in mcp/util.go, copy this into jsonschema for the code move. --- jsonschema/json_pointer.go | 4 +--- jsonschema/util.go | 7 +++++++ jsonschema/validate.go | 6 ++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/jsonschema/json_pointer.go b/jsonschema/json_pointer.go index 7310b9b4..d7eb4a9a 100644 --- a/jsonschema/json_pointer.go +++ b/jsonschema/json_pointer.go @@ -26,8 +26,6 @@ import ( "reflect" "strconv" "strings" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) var ( @@ -71,7 +69,7 @@ func parseJSONPointer(ptr string) (segments []string, err error) { // This implementation suffices for JSON Schema: pointers are applied only to Schemas, // and refer only to Schemas. func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { - defer util.Wrapf(&err, "JSON Pointer %q", sptr) + defer wrapf(&err, "JSON Pointer %q", sptr) segments, err := parseJSONPointer(sptr) if err != nil { diff --git a/jsonschema/util.go b/jsonschema/util.go index f5291536..25b916cd 100644 --- a/jsonschema/util.go +++ b/jsonschema/util.go @@ -454,3 +454,10 @@ func fieldJSONInfo(f reflect.StructField) jsonInfo { } return info } + +// wrapf wraps *errp with the given formatted message if *errp is not nil. +func wrapf(errp *error, format string, args ...any) { + if *errp != nil { + *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) + } +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go index ca4f340c..99ddd3b8 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -16,8 +16,6 @@ import ( "strings" "sync" "unicode/utf8" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) // The value of the "$schema" keyword for the version that we can validate. @@ -74,7 +72,7 @@ type state struct { // validate validates the reflected value of the instance. func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { - defer util.Wrapf(&err, "validating %s", st.rs.schemaString(schema)) + defer wrapf(&err, "validating %s", st.rs.schemaString(schema)) // Maintain a stack for dynamic schema resolution. st.stack = append(st.stack, schema) // push @@ -613,7 +611,7 @@ func (rs *Resolved) ApplyDefaults(instancep any) error { // Leave this as a potentially recursive helper function, because we'll surely want // to apply defaults on sub-schemas someday. func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { - defer util.Wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) + defer wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) schemaInfo := st.rs.resolvedInfos[schema] instance := instancep.Elem() From 4608401cc4b146e934c9f7e06da63a45a57964ef Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 7 Aug 2025 16:19:14 +0000 Subject: [PATCH 076/221] mcp: simplify ServerSession.initialize Make the simplifications described in #222. I don't know why we originally had such fine granularity of locking: perhaps we were going to execute 'oninitialized' callbacks. Also update the TODO with a bit more information about jsonrpc2 behavior. Fixes #222 --- mcp/server.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index c8878da3..b34cfe8b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -753,19 +753,16 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } ss.mu.Lock() + defer ss.mu.Unlock() ss.initializeParams = params - ss.mu.Unlock() // Mark the connection as initialized when this method exits. - // TODO: Technically, the server should not be considered initialized until it has - // *responded*, but we don't have adequate visibility into the jsonrpc2 - // connection to implement that easily. In any case, once we've initialized - // here, we can handle requests. - defer func() { - ss.mu.Lock() - ss.initialized = true - ss.mu.Unlock() - }() + // TODO(#26): Technically, the server should not be considered initialized + // until it has *responded*, but since jsonrpc2 is currently serialized we + // can mark the session as initialized here. If we ever implement a + // concurrency model (#26), we need to guarantee that initialize is not + // handled concurrently to other requests. + ss.initialized = true // If we support the client's version, reply with it. Otherwise, reply with our // latest version. From 5ab63feefa3f57eff213305c9c4714327639b7ac Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 7 Aug 2025 18:46:23 +0000 Subject: [PATCH 077/221] mcp: don't mark the server sesion as initialized prematurely Update ServerSession so that the session is not marked as initialized until 'notifications/initialized' is actually received from the client. Include a new test of this lifecycle strictness. Fixes #225 --- mcp/server.go | 40 +++++++------ .../conformance/server/bad_requests.txtar | 11 ++-- .../conformance/server/lifecycle.txtar | 57 +++++++++++++++++++ mcp/testdata/conformance/server/prompts.txtar | 2 + .../conformance/server/resources.txtar | 1 + mcp/testdata/conformance/server/tools.txtar | 1 + 6 files changed, 90 insertions(+), 22 deletions(-) create mode 100644 mcp/testdata/conformance/server/lifecycle.txtar diff --git a/mcp/server.go b/mcp/server.go index b34cfe8b..e69a872e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -567,11 +567,25 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro return connect(ctx, t, s) } -func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) { - if s.opts.KeepAlive > 0 { - ss.startKeepalive(s.opts.KeepAlive) +func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) } - return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params) + ss.mu.Lock() + hasParams := ss.initializeParams != nil + wasInitialized := ss._initialized + if hasParams { + ss._initialized = true + } + ss.mu.Unlock() + + if !hasParams { + return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) + } + if wasInitialized { + return nil, fmt.Errorf("duplicate %q received", notificationInitialized) + } + return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) } func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSession, params *RootsListChangedParams) (Result, error) { @@ -603,7 +617,7 @@ type ServerSession struct { mu sync.Mutex logLevel LoggingLevel initializeParams *InitializeParams - initialized bool + _initialized bool keepaliveCancel context.CancelFunc } @@ -702,7 +716,7 @@ var serverMethodInfos = map[string]methodInfo{ methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), 0), methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), 0), - notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), notification|missingParamsOK), + notificationInitialized: newMethodInfo(sessionMethod((*ServerSession).initialized), notification|missingParamsOK), notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), notification), } @@ -729,13 +743,13 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() - initialized := ss.initialized + initialized := ss._initialized ss.mu.Unlock() // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." switch req.Method { - case "initialize", "ping": + case methodInitialize, methodPing, notificationInitialized: default: if !initialized { return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) @@ -753,16 +767,8 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } ss.mu.Lock() - defer ss.mu.Unlock() ss.initializeParams = params - - // Mark the connection as initialized when this method exits. - // TODO(#26): Technically, the server should not be considered initialized - // until it has *responded*, but since jsonrpc2 is currently serialized we - // can mark the session as initialized here. If we ever implement a - // concurrency model (#26), we need to guarantee that initialize is not - // handled concurrently to other requests. - ss.initialized = true + ss.mu.Unlock() // If we support the client's version, reply with it. Otherwise, reply with our // latest version. diff --git a/mcp/testdata/conformance/server/bad_requests.txtar b/mcp/testdata/conformance/server/bad_requests.txtar index e9e9d483..44816189 100644 --- a/mcp/testdata/conformance/server/bad_requests.txtar +++ b/mcp/testdata/conformance/server/bad_requests.txtar @@ -38,11 +38,12 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } -{"jsonrpc":"2.0", "id": 3, "method": "notifications/initialized"} -{"jsonrpc":"2.0", "method":"ping"} -{"jsonrpc":"2.0", "id": 4, "method": "logging/setLevel"} -{"jsonrpc":"2.0", "id": 5, "method": "completion/complete"} -{"jsonrpc":"2.0", "id": 4, "method": "logging/setLevel", "params": null} +{ "jsonrpc":"2.0", "id": 3, "method": "notifications/initialized" } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc":"2.0", "method":"ping" } +{ "jsonrpc":"2.0", "id": 4, "method": "logging/setLevel" } +{ "jsonrpc":"2.0", "id": 5, "method": "completion/complete" } +{ "jsonrpc":"2.0", "id": 4, "method": "logging/setLevel", "params": null } -- server -- { diff --git a/mcp/testdata/conformance/server/lifecycle.txtar b/mcp/testdata/conformance/server/lifecycle.txtar new file mode 100644 index 00000000..eba287e0 --- /dev/null +++ b/mcp/testdata/conformance/server/lifecycle.txtar @@ -0,0 +1,57 @@ +This test checks that the server obeys the rules for initialization lifecycle, +and rejects non-ping requests until 'initialized' is received. + +See also modelcontextprotocol/go-sdk#225. + +-- client -- +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc":"2.0", "id": 1, "method":"ping" } +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc": "2.0", "id": 3, "method": "tools/list" } + +-- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "capabilities": { + "logging": {} + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 1, + "result": {} +} +{ + "jsonrpc": "2.0", + "id": 2, + "error": { + "code": 0, + "message": "method \"tools/list\" is invalid during session initialization" + } +} +{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "tools": [] + } +} diff --git a/mcp/testdata/conformance/server/prompts.txtar b/mcp/testdata/conformance/server/prompts.txtar index 3fd036e6..fdaf7932 100644 --- a/mcp/testdata/conformance/server/prompts.txtar +++ b/mcp/testdata/conformance/server/prompts.txtar @@ -18,9 +18,11 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } { "jsonrpc": "2.0", "id": 5, "method": "prompts/get" } + -- server -- { "jsonrpc": "2.0", diff --git a/mcp/testdata/conformance/server/resources.txtar b/mcp/testdata/conformance/server/resources.txtar index ae2e23cb..314817b8 100644 --- a/mcp/testdata/conformance/server/resources.txtar +++ b/mcp/testdata/conformance/server/resources.txtar @@ -21,6 +21,7 @@ info.txt "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "resources/list" } { "jsonrpc": "2.0", "id": 3, diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index 01fdb266..29dfdc18 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -20,6 +20,7 @@ greet "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{"jsonrpc":"2.0", "method": "notifications/initialized"} { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } From e9e0da8dac01a65a54fc66eb4955f534bf742991 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Thu, 7 Aug 2025 15:30:06 -0400 Subject: [PATCH 078/221] jsonschema: remove jsonschema code and depend on google/jsonschema-go (#262) This CL prunes the jsonschema directory and updates the go.mod. For #28 --- go.mod | 1 + go.sum | 2 + jsonschema/LICENSE | 21 - jsonschema/README.md | 39 - jsonschema/annotations.go | 76 - jsonschema/doc.go | 101 -- jsonschema/infer.go | 251 --- jsonschema/infer_test.go | 340 ---- jsonschema/json_pointer.go | 146 -- jsonschema/json_pointer_test.go | 78 - .../draft2020-12/meta/applicator.json | 45 - .../draft2020-12/meta/content.json | 14 - .../meta-schemas/draft2020-12/meta/core.json | 48 - .../draft2020-12/meta/format-annotation.json | 11 - .../draft2020-12/meta/meta-data.json | 34 - .../draft2020-12/meta/unevaluated.json | 12 - .../draft2020-12/meta/validation.json | 95 - .../meta-schemas/draft2020-12/schema.json | 58 - jsonschema/resolve.go | 548 ------ jsonschema/resolve_test.go | 218 --- jsonschema/schema.go | 436 ----- jsonschema/schema_test.go | 175 -- jsonschema/testdata/draft2020-12/README.md | 15 - .../draft2020-12/additionalProperties.json | 219 --- jsonschema/testdata/draft2020-12/allOf.json | 312 ---- jsonschema/testdata/draft2020-12/anchor.json | 120 -- jsonschema/testdata/draft2020-12/anyOf.json | 203 --- .../testdata/draft2020-12/boolean_schema.json | 104 -- jsonschema/testdata/draft2020-12/const.json | 387 ---- .../testdata/draft2020-12/contains.json | 176 -- jsonschema/testdata/draft2020-12/default.json | 82 - jsonschema/testdata/draft2020-12/defs.json | 21 - .../draft2020-12/dependentRequired.json | 152 -- .../draft2020-12/dependentSchemas.json | 171 -- .../testdata/draft2020-12/dynamicRef.json | 815 --------- jsonschema/testdata/draft2020-12/enum.json | 358 ---- .../draft2020-12/exclusiveMaximum.json | 31 - .../draft2020-12/exclusiveMinimum.json | 31 - .../testdata/draft2020-12/if-then-else.json | 268 --- .../draft2020-12/infinite-loop-detection.json | 37 - jsonschema/testdata/draft2020-12/items.json | 304 ---- .../testdata/draft2020-12/maxContains.json | 102 -- .../testdata/draft2020-12/maxItems.json | 50 - .../testdata/draft2020-12/maxLength.json | 55 - .../testdata/draft2020-12/maxProperties.json | 79 - jsonschema/testdata/draft2020-12/maximum.json | 60 - .../testdata/draft2020-12/minContains.json | 224 --- .../testdata/draft2020-12/minItems.json | 50 - .../testdata/draft2020-12/minLength.json | 55 - .../testdata/draft2020-12/minProperties.json | 60 - jsonschema/testdata/draft2020-12/minimum.json | 75 - .../testdata/draft2020-12/multipleOf.json | 97 - jsonschema/testdata/draft2020-12/not.json | 301 ---- jsonschema/testdata/draft2020-12/oneOf.json | 293 --- jsonschema/testdata/draft2020-12/pattern.json | 65 - .../draft2020-12/patternProperties.json | 176 -- .../testdata/draft2020-12/prefixItems.json | 104 -- .../testdata/draft2020-12/properties.json | 242 --- .../testdata/draft2020-12/propertyNames.json | 168 -- jsonschema/testdata/draft2020-12/ref.json | 1052 ----------- .../testdata/draft2020-12/refRemote.json | 342 ---- .../testdata/draft2020-12/required.json | 158 -- jsonschema/testdata/draft2020-12/type.json | 501 ------ .../draft2020-12/unevaluatedItems.json | 798 -------- .../draft2020-12/unevaluatedProperties.json | 1601 ----------------- .../testdata/draft2020-12/uniqueItems.json | 419 ----- jsonschema/testdata/remotes/README.md | 4 - .../remotes/different-id-ref-string.json | 5 - .../baseUriChange/folderInteger.json | 4 - .../baseUriChangeFolder/folderInteger.json | 4 - .../folderInteger.json | 4 - .../draft2020-12/detached-dynamicref.json | 13 - .../remotes/draft2020-12/detached-ref.json | 13 - .../draft2020-12/extendible-dynamic-ref.json | 21 - .../draft2020-12/format-assertion-false.json | 13 - .../draft2020-12/format-assertion-true.json | 13 - .../remotes/draft2020-12/integer.json | 4 - .../locationIndependentIdentifier.json | 12 - .../metaschema-no-validation.json | 13 - .../metaschema-optional-vocabulary.json | 14 - .../remotes/draft2020-12/name-defs.json | 16 - .../draft2020-12/nested/foo-ref-string.json | 7 - .../remotes/draft2020-12/nested/string.json | 4 - .../remotes/draft2020-12/prefixItems.json | 7 - .../remotes/draft2020-12/ref-and-defs.json | 12 - .../remotes/draft2020-12/subSchemas.json | 11 - .../testdata/remotes/draft2020-12/tree.json | 17 - .../nested-absolute-ref-to-string.json | 9 - .../testdata/remotes/urn-ref-string.json | 5 - jsonschema/util.go | 463 ----- jsonschema/util_test.go | 214 --- jsonschema/validate.go | 758 -------- jsonschema/validate_test.go | 294 --- mcp/client_list_test.go | 2 +- mcp/client_test.go | 2 +- mcp/example_middleware_test.go | 2 +- mcp/features_test.go | 2 +- mcp/mcp_test.go | 2 +- mcp/protocol.go | 2 +- mcp/server_test.go | 2 +- mcp/streamable_test.go | 2 +- mcp/tool.go | 2 +- mcp/tool_test.go | 2 +- 103 files changed, 13 insertions(+), 15643 deletions(-) delete mode 100644 jsonschema/LICENSE delete mode 100644 jsonschema/README.md delete mode 100644 jsonschema/annotations.go delete mode 100644 jsonschema/doc.go delete mode 100644 jsonschema/infer.go delete mode 100644 jsonschema/infer_test.go delete mode 100644 jsonschema/json_pointer.go delete mode 100644 jsonschema/json_pointer_test.go delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/applicator.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/content.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/core.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/meta-data.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/meta/validation.json delete mode 100644 jsonschema/meta-schemas/draft2020-12/schema.json delete mode 100644 jsonschema/resolve.go delete mode 100644 jsonschema/resolve_test.go delete mode 100644 jsonschema/schema.go delete mode 100644 jsonschema/schema_test.go delete mode 100644 jsonschema/testdata/draft2020-12/README.md delete mode 100644 jsonschema/testdata/draft2020-12/additionalProperties.json delete mode 100644 jsonschema/testdata/draft2020-12/allOf.json delete mode 100644 jsonschema/testdata/draft2020-12/anchor.json delete mode 100644 jsonschema/testdata/draft2020-12/anyOf.json delete mode 100644 jsonschema/testdata/draft2020-12/boolean_schema.json delete mode 100644 jsonschema/testdata/draft2020-12/const.json delete mode 100644 jsonschema/testdata/draft2020-12/contains.json delete mode 100644 jsonschema/testdata/draft2020-12/default.json delete mode 100644 jsonschema/testdata/draft2020-12/defs.json delete mode 100644 jsonschema/testdata/draft2020-12/dependentRequired.json delete mode 100644 jsonschema/testdata/draft2020-12/dependentSchemas.json delete mode 100644 jsonschema/testdata/draft2020-12/dynamicRef.json delete mode 100644 jsonschema/testdata/draft2020-12/enum.json delete mode 100644 jsonschema/testdata/draft2020-12/exclusiveMaximum.json delete mode 100644 jsonschema/testdata/draft2020-12/exclusiveMinimum.json delete mode 100644 jsonschema/testdata/draft2020-12/if-then-else.json delete mode 100644 jsonschema/testdata/draft2020-12/infinite-loop-detection.json delete mode 100644 jsonschema/testdata/draft2020-12/items.json delete mode 100644 jsonschema/testdata/draft2020-12/maxContains.json delete mode 100644 jsonschema/testdata/draft2020-12/maxItems.json delete mode 100644 jsonschema/testdata/draft2020-12/maxLength.json delete mode 100644 jsonschema/testdata/draft2020-12/maxProperties.json delete mode 100644 jsonschema/testdata/draft2020-12/maximum.json delete mode 100644 jsonschema/testdata/draft2020-12/minContains.json delete mode 100644 jsonschema/testdata/draft2020-12/minItems.json delete mode 100644 jsonschema/testdata/draft2020-12/minLength.json delete mode 100644 jsonschema/testdata/draft2020-12/minProperties.json delete mode 100644 jsonschema/testdata/draft2020-12/minimum.json delete mode 100644 jsonschema/testdata/draft2020-12/multipleOf.json delete mode 100644 jsonschema/testdata/draft2020-12/not.json delete mode 100644 jsonschema/testdata/draft2020-12/oneOf.json delete mode 100644 jsonschema/testdata/draft2020-12/pattern.json delete mode 100644 jsonschema/testdata/draft2020-12/patternProperties.json delete mode 100644 jsonschema/testdata/draft2020-12/prefixItems.json delete mode 100644 jsonschema/testdata/draft2020-12/properties.json delete mode 100644 jsonschema/testdata/draft2020-12/propertyNames.json delete mode 100644 jsonschema/testdata/draft2020-12/ref.json delete mode 100644 jsonschema/testdata/draft2020-12/refRemote.json delete mode 100644 jsonschema/testdata/draft2020-12/required.json delete mode 100644 jsonschema/testdata/draft2020-12/type.json delete mode 100644 jsonschema/testdata/draft2020-12/unevaluatedItems.json delete mode 100644 jsonschema/testdata/draft2020-12/unevaluatedProperties.json delete mode 100644 jsonschema/testdata/draft2020-12/uniqueItems.json delete mode 100644 jsonschema/testdata/remotes/README.md delete mode 100644 jsonschema/testdata/remotes/different-id-ref-string.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/detached-ref.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/integer.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/name-defs.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/nested/string.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/prefixItems.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/subSchemas.json delete mode 100644 jsonschema/testdata/remotes/draft2020-12/tree.json delete mode 100644 jsonschema/testdata/remotes/nested-absolute-ref-to-string.json delete mode 100644 jsonschema/testdata/remotes/urn-ref-string.json delete mode 100644 jsonschema/util.go delete mode 100644 jsonschema/util_test.go delete mode 100644 jsonschema/validate.go delete mode 100644 jsonschema/validate_test.go diff --git a/go.mod b/go.mod index 9bf8c151..17bddeb6 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 + github.com/google/jsonschema-go v0.2.0 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 7d2f581d..a2edf9ad 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.2.0 h1:Uh19091iHC56//WOsAd1oRg6yy1P9BpSvpjOL6RcjLQ= +github.com/google/jsonschema-go v0.2.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/jsonschema/LICENSE b/jsonschema/LICENSE deleted file mode 100644 index 508be926..00000000 --- a/jsonschema/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 Go MCP SDK Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/jsonschema/README.md b/jsonschema/README.md deleted file mode 100644 index f316bedd..00000000 --- a/jsonschema/README.md +++ /dev/null @@ -1,39 +0,0 @@ -TODO: this file should live at the root of the jsonschema-go module, -above the jsonschema package. - -# JSON Schema for GO - -This module implements the [JSON Schema](https://json-schema.org/) specification. -The `jsonschema` package supports creating schemas, validating JSON values -against a schema, and inferring a schema from a Go struct. See the package -documentation for usage. - -## Contributing - -This module welcomes external contributions. -It has no dependencies outside of the standard library, and can be built with -the standard Go toolchain. Run `go test ./...` at the module root to run all -the tests. - -## Issues - -This project uses the [GitHub issue -tracker](https://github.com/TODO/jsonschema-go/issues) for bug reports, feature requests, and other issues. - -Please [report -bugs](https://github.com/TODO/jsonschema-go/issues/new). If the SDK is -not working as you expected, it is likely due to a bug or inadequate -documentation, and reporting an issue will help us address this shortcoming. - -When reporting a bug, make sure to answer these five questions: - -1. What did you do? -2. What did you see? -3. What did you expect to see? -4. What version of the Go MCP SDK are you using? -5. What version of Go are you using (`go version`)? - -## License - -This project is licensed under the MIT license. See the LICENSE file for details. - diff --git a/jsonschema/annotations.go b/jsonschema/annotations.go deleted file mode 100644 index a7ede1c6..00000000 --- a/jsonschema/annotations.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import "maps" - -// An annotations tracks certain properties computed by keywords that are used by validation. -// ("Annotation" is the spec's term.) -// In particular, the unevaluatedItems and unevaluatedProperties keywords need to know which -// items and properties were evaluated (validated successfully). -type annotations struct { - allItems bool // all items were evaluated - endIndex int // 1+largest index evaluated by prefixItems - evaluatedIndexes map[int]bool // set of indexes evaluated by contains - allProperties bool // all properties were evaluated - evaluatedProperties map[string]bool // set of properties evaluated by various keywords -} - -// noteIndex marks i as evaluated. -func (a *annotations) noteIndex(i int) { - if a.evaluatedIndexes == nil { - a.evaluatedIndexes = map[int]bool{} - } - a.evaluatedIndexes[i] = true -} - -// noteEndIndex marks items with index less than end as evaluated. -func (a *annotations) noteEndIndex(end int) { - if end > a.endIndex { - a.endIndex = end - } -} - -// noteProperty marks prop as evaluated. -func (a *annotations) noteProperty(prop string) { - if a.evaluatedProperties == nil { - a.evaluatedProperties = map[string]bool{} - } - a.evaluatedProperties[prop] = true -} - -// noteProperties marks all the properties in props as evaluated. -func (a *annotations) noteProperties(props map[string]bool) { - a.evaluatedProperties = merge(a.evaluatedProperties, props) -} - -// merge adds b's annotations to a. -// a must not be nil. -func (a *annotations) merge(b *annotations) { - if b == nil { - return - } - if b.allItems { - a.allItems = true - } - if b.endIndex > a.endIndex { - a.endIndex = b.endIndex - } - a.evaluatedIndexes = merge(a.evaluatedIndexes, b.evaluatedIndexes) - if b.allProperties { - a.allProperties = true - } - a.evaluatedProperties = merge(a.evaluatedProperties, b.evaluatedProperties) -} - -// merge adds t's keys to s and returns s. -// If s is nil, it returns a copy of t. -func merge[K comparable](s, t map[K]bool) map[K]bool { - if s == nil { - return maps.Clone(t) - } - maps.Copy(s, t) - return s -} diff --git a/jsonschema/doc.go b/jsonschema/doc.go deleted file mode 100644 index 0f0ba441..00000000 --- a/jsonschema/doc.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -/* -Package jsonschema is an implementation of the [JSON Schema specification], -a JSON-based format for describing the structure of JSON data. -The package can be used to read schemas for code generation, and to validate -data using the draft 2020-12 specification. Validation with other drafts -or custom meta-schemas is not supported. - -Construct a [Schema] as you would any Go struct (for example, by writing -a struct literal), or unmarshal a JSON schema into a [Schema] in the usual -way (with [encoding/json], for instance). It can then be used for code -generation or other purposes without further processing. -You can also infer a schema from a Go struct. - -# Resolution - -A Schema can refer to other schemas, both inside and outside itself. These -references must be resolved before a schema can be used for validation. -Call [Schema.Resolve] to obtain a resolved schema (called a [Resolved]). -If the schema has external references, pass a [ResolveOptions] with a [Loader] -to load them. To validate default values in a schema, set -[ResolveOptions.ValidateDefaults] to true. - -# Validation - -Call [Resolved.Validate] to validate a JSON value. The value must be a -Go value that looks like the result of unmarshaling a JSON value into an -[any] or a struct. For example, the JSON value - - {"name": "Al", "scores": [90, 80, 100]} - -could be represented as the Go value - - map[string]any{ - "name": "Al", - "scores": []any{90, 80, 100}, - } - -or as a value of this type: - - type Player struct { - Name string `json:"name"` - Scores []int `json:"scores"` - } - -# Inference - -The [For] function returns a [Schema] describing the given Go type. -Each field in the struct becomes a property of the schema. -The values of "json" tags are respected: the field's property name is taken -from the tag, and fields omitted from the JSON are omitted from the schema as -well. -For example, `jsonschema.For[Player]()` returns this schema: - - { - "properties": { - "name": { - "type": "string" - }, - "scores": { - "type": "array", - "items": {"type": "integer"} - } - "required": ["name", "scores"], - "additionalProperties": {"not": {}} - } - } - -Use the "jsonschema" struct tag to provide a description for the property: - - type Player struct { - Name string `json:"name" jsonschema:"player name"` - Scores []int `json:"scores" jsonschema:"scores of player's games"` - } - -# Deviations from the specification - -Regular expressions are processed with Go's regexp package, which differs -from ECMA 262, most significantly in not supporting back-references. -See [this table of differences] for more. - -The "format" keyword described in [section 7 of the validation spec] is recorded -in the Schema, but is ignored during validation. -It does not even produce [annotations]. -Use the "pattern" keyword instead: it will work more reliably across JSON Schema -implementations. See [learnjsonschema.com] for more recommendations about "format". - -The content keywords described in [section 8 of the validation spec] -are recorded in the schema, but ignored during validation. - -[JSON Schema specification]: https://json-schema.org -[section 7 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 -[section 8 of the validation spec]: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 -[learnjsonschema.com]: https://www.learnjsonschema.com/2020-12/format-annotation/format/ -[this table of differences]: https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 -[annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations -*/ -package jsonschema diff --git a/jsonschema/infer.go b/jsonschema/infer.go deleted file mode 100644 index aa01fcd2..00000000 --- a/jsonschema/infer.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file contains functions that infer a schema from a Go type. - -package jsonschema - -import ( - "fmt" - "log/slog" - "maps" - "math/big" - "reflect" - "regexp" - "time" -) - -// ForOptions are options for the [For] function. -type ForOptions struct { - // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON - // Schema are ignored instead of causing an error. - // This allows callers to adjust the resulting schema using custom knowledge. - // For example, an interface type where all the possible implementations are - // known can be described with "oneof". - IgnoreInvalidTypes bool - - // TypeSchemas maps types to their schemas. - // If [For] encounters a type equal to a type of a key in this map, the - // corresponding value is used as the resulting schema (after cloning to - // ensure uniqueness). - // Types in this map override the default translations, as described - // in [For]'s documentation. - TypeSchemas map[any]*Schema -} - -// For constructs a JSON schema object for the given type argument. -// If non-nil, the provided options configure certain aspects of this contruction, -// described below. - -// It translates Go types into compatible JSON schema types, as follows. -// These defaults can be overridden by [ForOptions.TypeSchemas]. -// -// - Strings have schema type "string". -// - Bools have schema type "boolean". -// - Signed and unsigned integer types have schema type "integer". -// - Floating point types have schema type "number". -// - Slices and arrays have schema type "array", and a corresponding schema -// for items. -// - Maps with string key have schema type "object", and corresponding -// schema for additionalProperties. -// - Structs have schema type "object", and disallow additionalProperties. -// Their properties are derived from exported struct fields, using the -// struct field JSON name. Fields that are marked "omitempty" are -// considered optional; all other fields become required properties. -// - Some types in the standard library that implement json.Marshaler -// translate to schemas that match the values to which they marshal. -// For example, [time.Time] translates to the schema for strings. -// -// For will return an error if there is a cycle in the types. -// -// By default, For returns an error if t contains (possibly recursively) any of the -// following Go types, as they are incompatible with the JSON schema spec. -// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead. -// - maps with key other than 'string' -// - function types -// - channel types -// - complex numbers -// - unsafe pointers -// -// This function recognizes struct field tags named "jsonschema". -// A jsonschema tag on a field is used as the description for the corresponding property. -// For future compatibility, descriptions must not start with "WORD=", where WORD is a -// sequence of non-whitespace characters. -func For[T any](opts *ForOptions) (*Schema, error) { - if opts == nil { - opts = &ForOptions{} - } - schemas := maps.Clone(initialSchemaMap) - // Add types from the options. They override the default ones. - for v, s := range opts.TypeSchemas { - schemas[reflect.TypeOf(v)] = s - } - s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) - if err != nil { - var z T - return nil, fmt.Errorf("For[%T](): %w", z, err) - } - return s, nil -} - -// ForType is like [For], but takes a [reflect.Type] -func ForType(t reflect.Type, opts *ForOptions) (*Schema, error) { - schemas := maps.Clone(initialSchemaMap) - // Add types from the options. They override the default ones. - for v, s := range opts.TypeSchemas { - schemas[reflect.TypeOf(v)] = s - } - s, err := forType(t, map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) - if err != nil { - return nil, fmt.Errorf("ForType(%s): %w", t, err) - } - return s, nil -} - -func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { - // Follow pointers: the schema for *T is almost the same as for T, except that - // an explicit JSON "null" is allowed for the pointer. - allowNull := false - for t.Kind() == reflect.Pointer { - allowNull = true - t = t.Elem() - } - - // Check for cycles - // User defined types have a name, so we can skip those that are natively defined - if t.Name() != "" { - if seen[t] { - return nil, fmt.Errorf("cycle detected for type %v", t) - } - seen[t] = true - defer delete(seen, t) - } - - if s := schemas[t]; s != nil { - return s.CloneSchemas(), nil - } - - var ( - s = new(Schema) - err error - ) - - switch t.Kind() { - case reflect.Bool: - s.Type = "boolean" - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Uintptr: - s.Type = "integer" - - case reflect.Float32, reflect.Float64: - s.Type = "number" - - case reflect.Interface: - // Unrestricted - - case reflect.Map: - if t.Key().Kind() != reflect.String { - if ignore { - return nil, nil // ignore - } - return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) - } - if t.Key().Kind() != reflect.String { - } - s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas) - if err != nil { - return nil, fmt.Errorf("computing map value schema: %v", err) - } - if ignore && s.AdditionalProperties == nil { - // Ignore if the element type is invalid. - return nil, nil - } - - case reflect.Slice, reflect.Array: - s.Type = "array" - s.Items, err = forType(t.Elem(), seen, ignore, schemas) - if err != nil { - return nil, fmt.Errorf("computing element schema: %v", err) - } - if ignore && s.Items == nil { - // Ignore if the element type is invalid. - return nil, nil - } - if t.Kind() == reflect.Array { - s.MinItems = Ptr(t.Len()) - s.MaxItems = Ptr(t.Len()) - } - - case reflect.String: - s.Type = "string" - - case reflect.Struct: - s.Type = "object" - // no additional properties are allowed - s.AdditionalProperties = falseSchema() - - for i := range t.NumField() { - field := t.Field(i) - info := fieldJSONInfo(field) - if info.omit { - continue - } - if s.Properties == nil { - s.Properties = make(map[string]*Schema) - } - fs, err := forType(field.Type, seen, ignore, schemas) - if err != nil { - return nil, err - } - if ignore && fs == nil { - // Skip fields of invalid type. - continue - } - if tag, ok := field.Tag.Lookup("jsonschema"); ok { - if tag == "" { - return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) - } - if disallowedPrefixRegexp.MatchString(tag) { - return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag) - } - fs.Description = tag - } - s.Properties[info.name] = fs - if !info.settings["omitempty"] && !info.settings["omitzero"] { - s.Required = append(s.Required, info.name) - } - } - - default: - if ignore { - // Ignore. - return nil, nil - } - return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) - } - if allowNull && s.Type != "" { - s.Types = []string{"null", s.Type} - s.Type = "" - } - schemas[t] = s - return s, nil -} - -// initialSchemaMap holds types from the standard library that have MarshalJSON methods. -var initialSchemaMap = make(map[reflect.Type]*Schema) - -func init() { - ss := &Schema{Type: "string"} - initialSchemaMap[reflect.TypeFor[time.Time]()] = ss - initialSchemaMap[reflect.TypeFor[slog.Level]()] = ss - initialSchemaMap[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} - initialSchemaMap[reflect.TypeFor[big.Rat]()] = ss - initialSchemaMap[reflect.TypeFor[big.Float]()] = ss -} - -// Disallow jsonschema tag values beginning "WORD=", for future expansion. -var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go deleted file mode 100644 index 62bfbbbc..00000000 --- a/jsonschema/infer_test.go +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema_test - -import ( - "log/slog" - "math/big" - "reflect" - "strings" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" -) - -type custom int - -func forType[T any](ignore bool) *jsonschema.Schema { - var s *jsonschema.Schema - var err error - - opts := &jsonschema.ForOptions{ - IgnoreInvalidTypes: ignore, - TypeSchemas: map[any]*jsonschema.Schema{ - custom(0): {Type: "custom"}, - }, - } - s, err = jsonschema.For[T](opts) - if err != nil { - panic(err) - } - return s -} - -func TestFor(t *testing.T) { - type schema = jsonschema.Schema - - type S struct { - B int `jsonschema:"bdesc"` - } - - type test struct { - name string - got *jsonschema.Schema - want *jsonschema.Schema - } - - tests := func(ignore bool) []test { - return []test{ - {"string", forType[string](ignore), &schema{Type: "string"}}, - {"int", forType[int](ignore), &schema{Type: "integer"}}, - {"int16", forType[int16](ignore), &schema{Type: "integer"}}, - {"uint32", forType[int16](ignore), &schema{Type: "integer"}}, - {"float64", forType[float64](ignore), &schema{Type: "number"}}, - {"bool", forType[bool](ignore), &schema{Type: "boolean"}}, - {"time", forType[time.Time](ignore), &schema{Type: "string"}}, - {"level", forType[slog.Level](ignore), &schema{Type: "string"}}, - {"bigint", forType[big.Int](ignore), &schema{Types: []string{"null", "string"}}}, - {"custom", forType[custom](ignore), &schema{Type: "custom"}}, - {"intmap", forType[map[string]int](ignore), &schema{ - Type: "object", - AdditionalProperties: &schema{Type: "integer"}, - }}, - {"anymap", forType[map[string]any](ignore), &schema{ - Type: "object", - AdditionalProperties: &schema{}, - }}, - { - "struct", - forType[struct { - F int `json:"f" jsonschema:"fdesc"` - G []float64 - P *bool `jsonschema:"pdesc"` - Skip string `json:"-"` - NoSkip string `json:",omitempty"` - unexported float64 - unexported2 int `json:"No"` - }](ignore), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "f": {Type: "integer", Description: "fdesc"}, - "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, - "NoSkip": {Type: "string"}, - }, - Required: []string{"f", "G", "P"}, - AdditionalProperties: falseSchema(), - }, - }, - { - "no sharing", - forType[struct{ X, Y int }](ignore), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "X": {Type: "integer"}, - "Y": {Type: "integer"}, - }, - Required: []string{"X", "Y"}, - AdditionalProperties: falseSchema(), - }, - }, - { - "nested and embedded", - forType[struct { - A S - S - }](ignore), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "A": { - Type: "object", - Properties: map[string]*schema{ - "B": {Type: "integer", Description: "bdesc"}, - }, - Required: []string{"B"}, - AdditionalProperties: falseSchema(), - }, - "S": { - Type: "object", - Properties: map[string]*schema{ - "B": {Type: "integer", Description: "bdesc"}, - }, - Required: []string{"B"}, - AdditionalProperties: falseSchema(), - }, - }, - Required: []string{"A", "S"}, - AdditionalProperties: falseSchema(), - }, - }, - } - } - - run := func(t *testing.T, tt test) { - if diff := cmp.Diff(tt.want, tt.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("For mismatch (-want +got):\n%s", diff) - } - // These schemas should all resolve. - if _, err := tt.got.Resolve(nil); err != nil { - t.Fatalf("Resolving: %v", err) - } - } - - t.Run("strict", func(t *testing.T) { - for _, test := range tests(false) { - t.Run(test.name, func(t *testing.T) { run(t, test) }) - } - }) - - laxTests := append(tests(true), test{ - "ignore", - forType[struct { - A int - B map[int]int - C func() - }](true), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "A": {Type: "integer"}, - }, - Required: []string{"A"}, - AdditionalProperties: falseSchema(), - }, - }) - t.Run("lax", func(t *testing.T) { - for _, test := range laxTests { - t.Run(test.name, func(t *testing.T) { run(t, test) }) - } - }) -} - -func TestForType(t *testing.T) { - type schema = jsonschema.Schema - - // ForType is virtually identical to For. Just test that options are handled properly. - opts := &jsonschema.ForOptions{ - IgnoreInvalidTypes: true, - TypeSchemas: map[any]*jsonschema.Schema{ - custom(0): {Type: "custom"}, - }, - } - - type S struct { - I int - F func() - C custom - } - got, err := jsonschema.ForType(reflect.TypeOf(S{}), opts) - if err != nil { - t.Fatal(err) - } - want := &schema{ - Type: "object", - Properties: map[string]*schema{ - "I": {Type: "integer"}, - "C": {Type: "custom"}, - }, - Required: []string{"I", "C"}, - AdditionalProperties: falseSchema(), - } - if diff := cmp.Diff(want, got, cmpopts.IgnoreUnexported(schema{})); diff != "" { - t.Fatalf("ForType mismatch (-want +got):\n%s", diff) - } -} - -func forErr[T any]() error { - _, err := jsonschema.For[T](nil) - return err -} - -func TestForErrors(t *testing.T) { - type ( - s1 struct { - Empty int `jsonschema:""` - } - s2 struct { - Bad int `jsonschema:"$foo=1,bar"` - } - ) - - for _, tt := range []struct { - got error - want string - }{ - {forErr[map[int]int](), "unsupported map key type"}, - {forErr[s1](), "empty jsonschema tag"}, - {forErr[s2](), "must not begin with"}, - {forErr[func()](), "unsupported"}, - } { - if tt.got == nil { - t.Errorf("got nil, want error containing %q", tt.want) - } else if !strings.Contains(tt.got.Error(), tt.want) { - t.Errorf("got %q\nwant it to contain %q", tt.got, tt.want) - } - } -} - -func TestForWithMutation(t *testing.T) { - // This test ensures that the cached schema is not mutated when the caller - // mutates the returned schema. - type S struct { - A int - } - type T struct { - A int `json:"A"` - B map[string]int - C []S - D [3]S - E *bool - } - s, err := jsonschema.For[T](nil) - if err != nil { - t.Fatalf("For: %v", err) - } - s.Required[0] = "mutated" - s.Properties["A"].Type = "mutated" - s.Properties["C"].Items.Type = "mutated" - s.Properties["D"].MaxItems = jsonschema.Ptr(10) - s.Properties["D"].MinItems = jsonschema.Ptr(10) - s.Properties["E"].Types[0] = "mutated" - - s2, err := jsonschema.For[T](nil) - if err != nil { - t.Fatalf("For: %v", err) - } - if s2.Properties["A"].Type == "mutated" { - t.Fatalf("ForWithMutation: expected A.Type to not be mutated") - } - if s2.Properties["B"].AdditionalProperties.Type == "mutated" { - t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated") - } - if s2.Properties["C"].Items.Type == "mutated" { - t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated") - } - if *s2.Properties["D"].MaxItems == 10 { - t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated") - } - if *s2.Properties["D"].MinItems == 10 { - t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated") - } - if s2.Properties["E"].Types[0] == "mutated" { - t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated") - } - if s2.Required[0] == "mutated" { - t.Fatalf("ForWithMutation: expected Required[0] to not be mutated") - } -} - -type x struct { - Y y -} -type y struct { - X []x -} - -func TestForWithCycle(t *testing.T) { - type a []*a - type b1 struct{ b *b1 } // unexported field should be skipped - type b2 struct{ B *b2 } - type c1 struct{ c map[string]*c1 } // unexported field should be skipped - type c2 struct{ C map[string]*c2 } - - tests := []struct { - name string - shouldErr bool - fn func() error - }{ - {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](nil); return err }}, - {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](nil); return err }}, - {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](nil); return err }}, - {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](nil); return err }}, - {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](nil); return err }}, - {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](nil); return err }}, - {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](nil); return err }}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - err := test.fn() - if test.shouldErr && err == nil { - t.Errorf("expected cycle error, got nil") - } - if !test.shouldErr && err != nil { - t.Errorf("unexpected error: %v", err) - } - }) - } -} - -func falseSchema() *jsonschema.Schema { - return &jsonschema.Schema{Not: &jsonschema.Schema{}} -} diff --git a/jsonschema/json_pointer.go b/jsonschema/json_pointer.go deleted file mode 100644 index d7eb4a9a..00000000 --- a/jsonschema/json_pointer.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements JSON Pointers. -// A JSON Pointer is a path that refers to one JSON value within another. -// If the path is empty, it refers to the root value. -// Otherwise, it is a sequence of slash-prefixed strings, like "/points/1/x", -// selecting successive properties (for JSON objects) or items (for JSON arrays). -// For example, when applied to this JSON value: -// { -// "points": [ -// {"x": 1, "y": 2}, -// {"x": 3, "y": 4} -// ] -// } -// -// the JSON Pointer "/points/1/x" refers to the number 3. -// See the spec at https://datatracker.ietf.org/doc/html/rfc6901. - -package jsonschema - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -var ( - jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") - jsonPointerUnescaper = strings.NewReplacer("~0", "~", "~1", "/") -) - -func escapeJSONPointerSegment(s string) string { - return jsonPointerEscaper.Replace(s) -} - -func unescapeJSONPointerSegment(s string) string { - return jsonPointerUnescaper.Replace(s) -} - -// parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't -// convert strings to numbers, because that depends on the traversal: a segment -// is treated as a number when applied to an array, but a string when applied to -// an object. See section 4 of the spec. -func parseJSONPointer(ptr string) (segments []string, err error) { - if ptr == "" { - return nil, nil - } - if ptr[0] != '/' { - return nil, fmt.Errorf("JSON Pointer %q does not begin with '/'", ptr) - } - // Unlike file paths, consecutive slashes are not coalesced. - // Split is nicer than Cut here, because it gets a final "/" right. - segments = strings.Split(ptr[1:], "/") - if strings.Contains(ptr, "~") { - // Undo the simple escaping rules that allow one to include a slash in a segment. - for i := range segments { - segments[i] = unescapeJSONPointerSegment(segments[i]) - } - } - return segments, nil -} - -// dereferenceJSONPointer returns the Schema that sptr points to within s, -// or an error if none. -// This implementation suffices for JSON Schema: pointers are applied only to Schemas, -// and refer only to Schemas. -func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { - defer wrapf(&err, "JSON Pointer %q", sptr) - - segments, err := parseJSONPointer(sptr) - if err != nil { - return nil, err - } - v := reflect.ValueOf(s) - for _, seg := range segments { - switch v.Kind() { - case reflect.Pointer: - v = v.Elem() - if !v.IsValid() { - return nil, errors.New("navigated to nil reference") - } - fallthrough // if valid, can only be a pointer to a Schema - - case reflect.Struct: - // The segment must refer to a field in a Schema. - if v.Type() != reflect.TypeFor[Schema]() { - return nil, fmt.Errorf("navigated to non-Schema %s", v.Type()) - } - v = lookupSchemaField(v, seg) - if !v.IsValid() { - return nil, fmt.Errorf("no schema field %q", seg) - } - case reflect.Slice, reflect.Array: - // The segment must be an integer without leading zeroes that refers to an item in the - // slice or array. - if seg == "-" { - return nil, errors.New("the JSON Pointer array segment '-' is not supported") - } - if len(seg) > 1 && seg[0] == '0' { - return nil, fmt.Errorf("segment %q has leading zeroes", seg) - } - n, err := strconv.Atoi(seg) - if err != nil { - return nil, fmt.Errorf("invalid int: %q", seg) - } - if n < 0 || n >= v.Len() { - return nil, fmt.Errorf("index %d is out of bounds for array of length %d", n, v.Len()) - } - v = v.Index(n) - // Cannot be invalid. - case reflect.Map: - // The segment must be a key in the map. - v = v.MapIndex(reflect.ValueOf(seg)) - if !v.IsValid() { - return nil, fmt.Errorf("no key %q in map", seg) - } - default: - return nil, fmt.Errorf("value %s (%s) is not a schema, slice or map", v, v.Type()) - } - } - if s, ok := v.Interface().(*Schema); ok { - return s, nil - } - return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) -} - -// lookupSchemaField returns the value of the field with the given name in v, -// or the zero value if there is no such field or it is not of type Schema or *Schema. -func lookupSchemaField(v reflect.Value, name string) reflect.Value { - if name == "type" { - // The "type" keyword may refer to Type or Types. - // At most one will be non-zero. - if t := v.FieldByName("Type"); !t.IsZero() { - return t - } - return v.FieldByName("Types") - } - if sf, ok := schemaFieldMap[name]; ok { - return v.FieldByIndex(sf.Index) - } - return reflect.Value{} -} diff --git a/jsonschema/json_pointer_test.go b/jsonschema/json_pointer_test.go deleted file mode 100644 index 54b84bed..00000000 --- a/jsonschema/json_pointer_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "strings" - "testing" -) - -func TestDereferenceJSONPointer(t *testing.T) { - s := &Schema{ - AllOf: []*Schema{{}, {}}, - Defs: map[string]*Schema{ - "": {Properties: map[string]*Schema{"": {}}}, - "A": {}, - "B": { - Defs: map[string]*Schema{ - "X": {}, - "Y": {}, - }, - }, - "/~": {}, - "~1": {}, - }, - } - - for _, tt := range []struct { - ptr string - want any - }{ - {"", s}, - {"/$defs/A", s.Defs["A"]}, - {"/$defs/B", s.Defs["B"]}, - {"/$defs/B/$defs/X", s.Defs["B"].Defs["X"]}, - {"/$defs//properties/", s.Defs[""].Properties[""]}, - {"/allOf/1", s.AllOf[1]}, - {"/$defs/~1~0", s.Defs["/~"]}, - {"/$defs/~01", s.Defs["~1"]}, - } { - got, err := dereferenceJSONPointer(s, tt.ptr) - if err != nil { - t.Fatal(err) - } - if got != tt.want { - t.Errorf("%s:\ngot %+v\nwant %+v", tt.ptr, got, tt.want) - } - } -} - -func TestDerefernceJSONPointerErrors(t *testing.T) { - s := &Schema{ - Type: "t", - Items: &Schema{}, - Required: []string{"a"}, - } - for _, tt := range []struct { - ptr string - want string // error must contain this string - }{ - {"x", "does not begin"}, // parse error: no initial '/' - {"/minItems", "does not refer to a schema"}, - {"/minItems/x", "navigated to nil"}, - {"/required/-", "not supported"}, - {"/required/01", "leading zeroes"}, - {"/required/x", "invalid int"}, - {"/required/1", "out of bounds"}, - {"/properties/x", "no key"}, - } { - _, err := dereferenceJSONPointer(s, tt.ptr) - if err == nil { - t.Errorf("%q: succeeded, want failure", tt.ptr) - } else if !strings.Contains(err.Error(), tt.want) { - t.Errorf("%q: error is %q, which does not contain %q", tt.ptr, err, tt.want) - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/applicator.json b/jsonschema/meta-schemas/draft2020-12/meta/applicator.json deleted file mode 100644 index f4775974..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/applicator.json +++ /dev/null @@ -1,45 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/applicator", - "$dynamicAnchor": "meta", - - "title": "Applicator vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "prefixItems": { "$ref": "#/$defs/schemaArray" }, - "items": { "$dynamicRef": "#meta" }, - "contains": { "$dynamicRef": "#meta" }, - "additionalProperties": { "$dynamicRef": "#meta" }, - "properties": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "default": {} - }, - "patternProperties": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "propertyNames": { "format": "regex" }, - "default": {} - }, - "dependentSchemas": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "default": {} - }, - "propertyNames": { "$dynamicRef": "#meta" }, - "if": { "$dynamicRef": "#meta" }, - "then": { "$dynamicRef": "#meta" }, - "else": { "$dynamicRef": "#meta" }, - "allOf": { "$ref": "#/$defs/schemaArray" }, - "anyOf": { "$ref": "#/$defs/schemaArray" }, - "oneOf": { "$ref": "#/$defs/schemaArray" }, - "not": { "$dynamicRef": "#meta" } - }, - "$defs": { - "schemaArray": { - "type": "array", - "minItems": 1, - "items": { "$dynamicRef": "#meta" } - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/content.json b/jsonschema/meta-schemas/draft2020-12/meta/content.json deleted file mode 100644 index 76e3760d..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/content.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/content", - "$dynamicAnchor": "meta", - - "title": "Content vocabulary meta-schema", - - "type": ["object", "boolean"], - "properties": { - "contentEncoding": { "type": "string" }, - "contentMediaType": { "type": "string" }, - "contentSchema": { "$dynamicRef": "#meta" } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/core.json b/jsonschema/meta-schemas/draft2020-12/meta/core.json deleted file mode 100644 index 69186228..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/core.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/core", - "$dynamicAnchor": "meta", - - "title": "Core vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "$id": { - "$ref": "#/$defs/uriReferenceString", - "$comment": "Non-empty fragments not allowed.", - "pattern": "^[^#]*#?$" - }, - "$schema": { "$ref": "#/$defs/uriString" }, - "$ref": { "$ref": "#/$defs/uriReferenceString" }, - "$anchor": { "$ref": "#/$defs/anchorString" }, - "$dynamicRef": { "$ref": "#/$defs/uriReferenceString" }, - "$dynamicAnchor": { "$ref": "#/$defs/anchorString" }, - "$vocabulary": { - "type": "object", - "propertyNames": { "$ref": "#/$defs/uriString" }, - "additionalProperties": { - "type": "boolean" - } - }, - "$comment": { - "type": "string" - }, - "$defs": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" } - } - }, - "$defs": { - "anchorString": { - "type": "string", - "pattern": "^[A-Za-z_][-A-Za-z0-9._]*$" - }, - "uriString": { - "type": "string", - "format": "uri" - }, - "uriReferenceString": { - "type": "string", - "format": "uri-reference" - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json b/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json deleted file mode 100644 index 3479e669..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/format-annotation", - "$dynamicAnchor": "meta", - - "title": "Format vocabulary meta-schema for annotation results", - "type": ["object", "boolean"], - "properties": { - "format": { "type": "string" } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json b/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json deleted file mode 100644 index 4049ab21..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/meta-data", - "$dynamicAnchor": "meta", - - "title": "Meta-data vocabulary meta-schema", - - "type": ["object", "boolean"], - "properties": { - "title": { - "type": "string" - }, - "description": { - "type": "string" - }, - "default": true, - "deprecated": { - "type": "boolean", - "default": false - }, - "readOnly": { - "type": "boolean", - "default": false - }, - "writeOnly": { - "type": "boolean", - "default": false - }, - "examples": { - "type": "array", - "items": true - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json b/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json deleted file mode 100644 index 93779e54..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/unevaluated", - "$dynamicAnchor": "meta", - - "title": "Unevaluated applicator vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "unevaluatedItems": { "$dynamicRef": "#meta" }, - "unevaluatedProperties": { "$dynamicRef": "#meta" } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/validation.json b/jsonschema/meta-schemas/draft2020-12/meta/validation.json deleted file mode 100644 index ebb75db7..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/validation.json +++ /dev/null @@ -1,95 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/validation", - "$dynamicAnchor": "meta", - - "title": "Validation vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "type": { - "anyOf": [ - { "$ref": "#/$defs/simpleTypes" }, - { - "type": "array", - "items": { "$ref": "#/$defs/simpleTypes" }, - "minItems": 1, - "uniqueItems": true - } - ] - }, - "const": true, - "enum": { - "type": "array", - "items": true - }, - "multipleOf": { - "type": "number", - "exclusiveMinimum": 0 - }, - "maximum": { - "type": "number" - }, - "exclusiveMaximum": { - "type": "number" - }, - "minimum": { - "type": "number" - }, - "exclusiveMinimum": { - "type": "number" - }, - "maxLength": { "$ref": "#/$defs/nonNegativeInteger" }, - "minLength": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, - "pattern": { - "type": "string", - "format": "regex" - }, - "maxItems": { "$ref": "#/$defs/nonNegativeInteger" }, - "minItems": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, - "uniqueItems": { - "type": "boolean", - "default": false - }, - "maxContains": { "$ref": "#/$defs/nonNegativeInteger" }, - "minContains": { - "$ref": "#/$defs/nonNegativeInteger", - "default": 1 - }, - "maxProperties": { "$ref": "#/$defs/nonNegativeInteger" }, - "minProperties": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, - "required": { "$ref": "#/$defs/stringArray" }, - "dependentRequired": { - "type": "object", - "additionalProperties": { - "$ref": "#/$defs/stringArray" - } - } - }, - "$defs": { - "nonNegativeInteger": { - "type": "integer", - "minimum": 0 - }, - "nonNegativeIntegerDefault0": { - "$ref": "#/$defs/nonNegativeInteger", - "default": 0 - }, - "simpleTypes": { - "enum": [ - "array", - "boolean", - "integer", - "null", - "number", - "object", - "string" - ] - }, - "stringArray": { - "type": "array", - "items": { "type": "string" }, - "uniqueItems": true, - "default": [] - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/schema.json b/jsonschema/meta-schemas/draft2020-12/schema.json deleted file mode 100644 index d5e2d31c..00000000 --- a/jsonschema/meta-schemas/draft2020-12/schema.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/schema", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/core": true, - "https://json-schema.org/draft/2020-12/vocab/applicator": true, - "https://json-schema.org/draft/2020-12/vocab/unevaluated": true, - "https://json-schema.org/draft/2020-12/vocab/validation": true, - "https://json-schema.org/draft/2020-12/vocab/meta-data": true, - "https://json-schema.org/draft/2020-12/vocab/format-annotation": true, - "https://json-schema.org/draft/2020-12/vocab/content": true - }, - "$dynamicAnchor": "meta", - - "title": "Core and Validation specifications meta-schema", - "allOf": [ - {"$ref": "meta/core"}, - {"$ref": "meta/applicator"}, - {"$ref": "meta/unevaluated"}, - {"$ref": "meta/validation"}, - {"$ref": "meta/meta-data"}, - {"$ref": "meta/format-annotation"}, - {"$ref": "meta/content"} - ], - "type": ["object", "boolean"], - "$comment": "This meta-schema also defines keywords that have appeared in previous drafts in order to prevent incompatible extensions as they remain in common use.", - "properties": { - "definitions": { - "$comment": "\"definitions\" has been replaced by \"$defs\".", - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "deprecated": true, - "default": {} - }, - "dependencies": { - "$comment": "\"dependencies\" has been split and replaced by \"dependentSchemas\" and \"dependentRequired\" in order to serve their differing semantics.", - "type": "object", - "additionalProperties": { - "anyOf": [ - { "$dynamicRef": "#meta" }, - { "$ref": "meta/validation#/$defs/stringArray" } - ] - }, - "deprecated": true, - "default": {} - }, - "$recursiveAnchor": { - "$comment": "\"$recursiveAnchor\" has been replaced by \"$dynamicAnchor\".", - "$ref": "meta/core#/$defs/anchorString", - "deprecated": true - }, - "$recursiveRef": { - "$comment": "\"$recursiveRef\" has been replaced by \"$dynamicRef\".", - "$ref": "meta/core#/$defs/uriReferenceString", - "deprecated": true - } - } -} diff --git a/jsonschema/resolve.go b/jsonschema/resolve.go deleted file mode 100644 index cc551e79..00000000 --- a/jsonschema/resolve.go +++ /dev/null @@ -1,548 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file deals with preparing a schema for validation, including various checks, -// optimizations, and the resolution of cross-schema references. - -package jsonschema - -import ( - "errors" - "fmt" - "net/url" - "reflect" - "regexp" - "strings" -) - -// A Resolved consists of a [Schema] along with associated information needed to -// validate documents against it. -// A Resolved has been validated against its meta-schema, and all its references -// (the $ref and $dynamicRef keywords) have been resolved to their referenced Schemas. -// Call [Schema.Resolve] to obtain a Resolved from a Schema. -type Resolved struct { - root *Schema - // map from $ids to their schemas - resolvedURIs map[string]*Schema - // map from schemas to additional info computed during resolution - resolvedInfos map[*Schema]*resolvedInfo -} - -func newResolved(s *Schema) *Resolved { - return &Resolved{ - root: s, - resolvedURIs: map[string]*Schema{}, - resolvedInfos: map[*Schema]*resolvedInfo{}, - } -} - -// resolvedInfo holds information specific to a schema that is computed by [Schema.Resolve]. -type resolvedInfo struct { - s *Schema - // The JSON Pointer path from the root schema to here. - // Used in errors. - path string - // The schema's base schema. - // If the schema is the root or has an ID, its base is itself. - // Otherwise, its base is the innermost enclosing schema whose base - // is itself. - // Intuitively, a base schema is one that can be referred to with a - // fragmentless URI. - base *Schema - // The URI for the schema, if it is the root or has an ID. - // Otherwise nil. - // Invariants: - // s.base.uri != nil. - // s.base == s <=> s.uri != nil - uri *url.URL - // The schema to which Ref refers. - resolvedRef *Schema - - // If the schema has a dynamic ref, exactly one of the next two fields - // will be non-zero after successful resolution. - // The schema to which the dynamic ref refers when it acts lexically. - resolvedDynamicRef *Schema - // The anchor to look up on the stack when the dynamic ref acts dynamically. - dynamicRefAnchor string - - // The following fields are independent of arguments to Schema.Resolved, - // so they could live on the Schema. We put them here for simplicity. - - // The set of required properties. - isRequired map[string]bool - - // Compiled regexps. - pattern *regexp.Regexp - patternProperties map[*regexp.Regexp]*Schema - - // Map from anchors to subschemas. - anchors map[string]anchorInfo -} - -// Schema returns the schema that was resolved. -// It must not be modified. -func (r *Resolved) Schema() *Schema { return r.root } - -// schemaString returns a short string describing the schema. -func (r *Resolved) schemaString(s *Schema) string { - if s.ID != "" { - return s.ID - } - info := r.resolvedInfos[s] - if info.path != "" { - return info.path - } - return "" -} - -// A Loader reads and unmarshals the schema at uri, if any. -type Loader func(uri *url.URL) (*Schema, error) - -// ResolveOptions are options for [Schema.Resolve]. -type ResolveOptions struct { - // BaseURI is the URI relative to which the root schema should be resolved. - // If non-empty, must be an absolute URI (one that starts with a scheme). - // It is resolved (in the URI sense; see [url.ResolveReference]) with root's - // $id property. - // If the resulting URI is not absolute, then the schema cannot contain - // relative URI references. - BaseURI string - // Loader loads schemas that are referred to by a $ref but are not under the - // root schema (remote references). - // If nil, resolving a remote reference will return an error. - Loader Loader - // ValidateDefaults determines whether to validate values of "default" keywords - // against their schemas. - // The [JSON Schema specification] does not require this, but it is - // recommended if defaults will be used. - // - // [JSON Schema specification]: https://json-schema.org/understanding-json-schema/reference/annotations - ValidateDefaults bool -} - -// Resolve resolves all references within the schema and performs other tasks that -// prepare the schema for validation. -// If opts is nil, the default values are used. -// The schema must not be changed after Resolve is called. -// The same schema may be resolved multiple times. -func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { - // There are up to five steps required to prepare a schema to validate. - // 1. Load: read the schema from somewhere and unmarshal it. - // This schema (root) may have been loaded or created in memory, but other schemas that - // come into the picture in step 4 will be loaded by the given loader. - // 2. Check: validate the schema against a meta-schema, and perform other well-formedness checks. - // Precompute some values along the way. - // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and - // resolve (in the URI sense) all identifiers and anchors with their bases. This step results - // in a map from URIs to schemas within root. - // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. - // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. - r := &resolver{loaded: map[string]*Resolved{}} - if opts != nil { - r.opts = *opts - } - var base *url.URL - if r.opts.BaseURI == "" { - base = &url.URL{} // so we can call ResolveReference on it - } else { - var err error - base, err = url.Parse(r.opts.BaseURI) - if err != nil { - return nil, fmt.Errorf("parsing base URI: %w", err) - } - } - - if r.opts.Loader == nil { - r.opts.Loader = func(uri *url.URL) (*Schema, error) { - return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") - } - } - - resolved, err := r.resolve(root, base) - if err != nil { - return nil, err - } - if r.opts.ValidateDefaults { - if err := resolved.validateDefaults(); err != nil { - return nil, err - } - } - // TODO: before we return, throw away anything we don't need for validation. - return resolved, nil -} - -// A resolver holds the state for resolution. -type resolver struct { - opts ResolveOptions - // A cache of loaded and partly resolved schemas. (They may not have had their - // refs resolved.) The cache ensures that the loader will never be called more - // than once with the same URI, and that reference cycles are handled properly. - loaded map[string]*Resolved -} - -func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { - if baseURI.Fragment != "" { - return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) - } - rs := newResolved(s) - - if err := s.check(rs.resolvedInfos); err != nil { - return nil, err - } - - if err := resolveURIs(rs, baseURI); err != nil { - return nil, err - } - - // Remember the schema by both the URI we loaded it from and its canonical name, - // which may differ if the schema has an $id. - // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. - r.loaded[baseURI.String()] = rs - r.loaded[rs.resolvedInfos[s].uri.String()] = rs - - if err := r.resolveRefs(rs); err != nil { - return nil, err - } - return rs, nil -} - -func (root *Schema) check(infos map[*Schema]*resolvedInfo) error { - // Check for structural validity. Do this first and fail fast: - // bad structure will cause other code to panic. - if err := root.checkStructure(infos); err != nil { - return err - } - - var errs []error - report := func(err error) { errs = append(errs, err) } - - for ss := range root.all() { - ss.checkLocal(report, infos) - } - return errors.Join(errs...) -} - -// checkStructure verifies that root and its subschemas form a tree. -// It also assigns each schema a unique path, to improve error messages. -func (root *Schema) checkStructure(infos map[*Schema]*resolvedInfo) error { - assert(len(infos) == 0, "non-empty infos") - - var check func(reflect.Value, []byte) error - check = func(v reflect.Value, path []byte) error { - // For the purpose of error messages, the root schema has path "root" - // and other schemas' paths are their JSON Pointer from the root. - p := "root" - if len(path) > 0 { - p = string(path) - } - s := v.Interface().(*Schema) - if s == nil { - return fmt.Errorf("jsonschema: schema at %s is nil", p) - } - if info, ok := infos[s]; ok { - // We've seen s before. - // The schema graph at root is not a tree, but it needs to - // be because a schema's base must be unique. - // A cycle would also put Schema.all into an infinite recursion. - return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", - root, info.path, p) - } - infos[s] = &resolvedInfo{s: s, path: p} - - for _, info := range schemaFieldInfos { - fv := v.Elem().FieldByIndex(info.sf.Index) - switch info.sf.Type { - case schemaType: - // A field that contains an individual schema. - // A nil is valid: it just means the field isn't present. - if !fv.IsNil() { - if err := check(fv, fmt.Appendf(path, "/%s", info.jsonName)); err != nil { - return err - } - } - - case schemaSliceType: - for i := range fv.Len() { - if err := check(fv.Index(i), fmt.Appendf(path, "/%s/%d", info.jsonName, i)); err != nil { - return err - } - } - - case schemaMapType: - iter := fv.MapRange() - for iter.Next() { - key := escapeJSONPointerSegment(iter.Key().String()) - if err := check(iter.Value(), fmt.Appendf(path, "/%s/%s", info.jsonName, key)); err != nil { - return err - } - } - } - - } - return nil - } - - return check(reflect.ValueOf(root), make([]byte, 0, 256)) -} - -// checkLocal checks s for validity, independently of other schemas it may refer to. -// Since checking a regexp involves compiling it, checkLocal saves those compiled regexps -// in the schema for later use. -// It appends the errors it finds to errs. -func (s *Schema) checkLocal(report func(error), infos map[*Schema]*resolvedInfo) { - addf := func(format string, args ...any) { - msg := fmt.Sprintf(format, args...) - report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) - } - - if s == nil { - addf("nil subschema") - return - } - if err := s.basicChecks(); err != nil { - report(err) - return - } - - // TODO: validate the schema's properties, - // ideally by jsonschema-validating it against the meta-schema. - - // Some properties are present so that Schemas can round-trip, but we do not - // validate them. - // Currently, it's just the $vocabulary property. - // As a special case, we can validate the 2020-12 meta-schema. - if s.Vocabulary != nil && s.Schema != draft202012 { - addf("cannot validate a schema with $vocabulary") - } - - info := infos[s] - - // Check and compile regexps. - if s.Pattern != "" { - re, err := regexp.Compile(s.Pattern) - if err != nil { - addf("pattern: %v", err) - } else { - info.pattern = re - } - } - if len(s.PatternProperties) > 0 { - info.patternProperties = map[*regexp.Regexp]*Schema{} - for reString, subschema := range s.PatternProperties { - re, err := regexp.Compile(reString) - if err != nil { - addf("patternProperties[%q]: %v", reString, err) - continue - } - info.patternProperties[re] = subschema - } - } - - // Build a set of required properties, to avoid quadratic behavior when validating - // a struct. - if len(s.Required) > 0 { - info.isRequired = map[string]bool{} - for _, r := range s.Required { - info.isRequired[r] = true - } - } -} - -// resolveURIs resolves the ids and anchors in all the schemas of root, relative -// to baseURI. -// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section -// 8.2.1. -// -// Every schema has a base URI and a parent base URI. -// -// The parent base URI is the base URI of the lexically enclosing schema, or for -// a root schema, the URI it was loaded from or the one supplied to [Schema.Resolve]. -// -// If the schema has no $id property, the base URI of a schema is that of its parent. -// If the schema does have an $id, it must be a URI, possibly relative. The schema's -// base URI is the $id resolved (in the sense of [url.URL.ResolveReference]) against -// the parent base. -// -// As an example, consider this schema loaded from http://a.com/root.json (quotes omitted): -// -// { -// allOf: [ -// {$id: "sub1.json", minLength: 5}, -// {$id: "http://b.com", minimum: 10}, -// {not: {maximum: 20}} -// ] -// } -// -// The base URIs are as follows. Schema locations are expressed in the JSON Pointer notation. -// -// schema base URI -// root http://a.com/root.json -// allOf/0 http://a.com/sub1.json -// allOf/1 http://b.com (absolute $id; doesn't matter that it's not under the loaded URI) -// allOf/2 http://a.com/root.json (inherited from parent) -// allOf/2/not http://a.com/root.json (inherited from parent) -func resolveURIs(rs *Resolved, baseURI *url.URL) error { - var resolve func(s, base *Schema) error - resolve = func(s, base *Schema) error { - info := rs.resolvedInfos[s] - baseInfo := rs.resolvedInfos[base] - - // ids are scoped to the root. - if s.ID != "" { - // A non-empty ID establishes a new base. - idURI, err := url.Parse(s.ID) - if err != nil { - return err - } - if idURI.Fragment != "" { - return fmt.Errorf("$id %s must not have a fragment", s.ID) - } - // The base URI for this schema is its $id resolved against the parent base. - info.uri = baseInfo.uri.ResolveReference(idURI) - if !info.uri.IsAbs() { - return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %q)", s.ID, baseInfo.uri) - } - rs.resolvedURIs[info.uri.String()] = s - base = s // needed for anchors - baseInfo = rs.resolvedInfos[base] - } - info.base = base - - // Anchors and dynamic anchors are URI fragments that are scoped to their base. - // We treat them as keys in a map stored within the schema. - setAnchor := func(anchor string, dynamic bool) error { - if anchor != "" { - if _, ok := baseInfo.anchors[anchor]; ok { - return fmt.Errorf("duplicate anchor %q in %s", anchor, baseInfo.uri) - } - if baseInfo.anchors == nil { - baseInfo.anchors = map[string]anchorInfo{} - } - baseInfo.anchors[anchor] = anchorInfo{s, dynamic} - } - return nil - } - - setAnchor(s.Anchor, false) - setAnchor(s.DynamicAnchor, true) - - for c := range s.children() { - if err := resolve(c, base); err != nil { - return err - } - } - return nil - } - - // Set the root URI to the base for now. If the root has an $id, this will change. - rs.resolvedInfos[rs.root].uri = baseURI - // The original base, even if changed, is still a valid way to refer to the root. - rs.resolvedURIs[baseURI.String()] = rs.root - - return resolve(rs.root, rs.root) -} - -// resolveRefs replaces every ref in the schemas with the schema it refers to. -// A reference that doesn't resolve within the schema may refer to some other schema -// that needs to be loaded. -func (r *resolver) resolveRefs(rs *Resolved) error { - for s := range rs.root.all() { - info := rs.resolvedInfos[s] - if s.Ref != "" { - refSchema, _, err := r.resolveRef(rs, s, s.Ref) - if err != nil { - return err - } - // Whether or not the anchor referred to by $ref fragment is dynamic, - // the ref still treats it lexically. - info.resolvedRef = refSchema - } - if s.DynamicRef != "" { - refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) - if err != nil { - return err - } - if frag != "" { - // The dynamic ref's fragment points to a dynamic anchor. - // We must resolve the fragment at validation time. - info.dynamicRefAnchor = frag - } else { - // There is no dynamic anchor in the lexically referenced schema, - // so the dynamic ref behaves like a lexical ref. - info.resolvedDynamicRef = refSchema - } - } - } - return nil -} - -// resolveRef resolves the reference ref, which is either s.Ref or s.DynamicRef. -func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, dynamicFragment string, err error) { - refURI, err := url.Parse(ref) - if err != nil { - return nil, "", err - } - // URI-resolve the ref against the current base URI to get a complete URI. - base := rs.resolvedInfos[s].base - refURI = rs.resolvedInfos[base].uri.ResolveReference(refURI) - // The non-fragment part of a ref URI refers to the base URI of some schema. - // This part is the same for dynamic refs too: their non-fragment part resolves - // lexically. - u := *refURI - u.Fragment = "" - fraglessRefURI := &u - // Look it up locally. - referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] - if referencedSchema == nil { - // The schema is remote. Maybe we've already loaded it. - // We assume that the non-fragment part of refURI refers to a top-level schema - // document. That is, we don't support the case exemplified by - // http://foo.com/bar.json/baz, where the document is in bar.json and - // the reference points to a subschema within it. - // TODO: support that case. - if lrs := r.loaded[fraglessRefURI.String()]; lrs != nil { - referencedSchema = lrs.root - } else { - // Try to load the schema. - ls, err := r.opts.Loader(fraglessRefURI) - if err != nil { - return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err) - } - lrs, err := r.resolve(ls, fraglessRefURI) - if err != nil { - return nil, "", err - } - referencedSchema = lrs.root - assert(referencedSchema != nil, "nil referenced schema") - // Copy the resolvedInfos from lrs into rs, without overwriting - // (hence we can't use maps.Insert). - for s, i := range lrs.resolvedInfos { - if rs.resolvedInfos[s] == nil { - rs.resolvedInfos[s] = i - } - } - } - } - - frag := refURI.Fragment - // Look up frag in refSchema. - // frag is either a JSON Pointer or the name of an anchor. - // A JSON Pointer is either the empty string or begins with a '/', - // whereas anchors are always non-empty strings that don't contain slashes. - if frag != "" && !strings.HasPrefix(frag, "/") { - resInfo := rs.resolvedInfos[referencedSchema] - info, found := resInfo.anchors[frag] - - if !found { - return nil, "", fmt.Errorf("no anchor %q in %s", frag, s) - } - if info.dynamic { - dynamicFragment = frag - } - return info.schema, dynamicFragment, nil - } - // frag is a JSON Pointer. - s, err = dereferenceJSONPointer(referencedSchema, frag) - return s, "", err -} diff --git a/jsonschema/resolve_test.go b/jsonschema/resolve_test.go deleted file mode 100644 index 36aa424b..00000000 --- a/jsonschema/resolve_test.go +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "errors" - "maps" - "net/url" - "regexp" - "slices" - "strings" - "testing" -) - -func TestSchemaStructure(t *testing.T) { - check := func(s *Schema, want string) { - t.Helper() - infos := map[*Schema]*resolvedInfo{} - err := s.checkStructure(infos) - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("checkStructure returned error %q, want %q", err, want) - } - } - - dag := &Schema{Type: "number"} - dag = &Schema{Items: dag, Contains: dag} - check(dag, "do not form a tree") - - tree := &Schema{Type: "number"} - tree.Items = tree - check(tree, "do not form a tree") - - sliceNil := &Schema{PrefixItems: []*Schema{nil}} - check(sliceNil, "is nil") - - sliceMap := &Schema{Properties: map[string]*Schema{"a": nil}} - check(sliceMap, "is nil") -} - -func TestCheckLocal(t *testing.T) { - for _, tt := range []struct { - s *Schema - want string // error must be non-nil and match this regexp - }{ - { - &Schema{Pattern: "]["}, - "regexp", - }, - { - &Schema{PatternProperties: map[string]*Schema{"*": {}}}, - "regexp", - }, - } { - _, err := tt.s.Resolve(nil) - if err == nil { - t.Errorf("%s: unexpectedly passed", tt.s.json()) - continue - } - if !regexp.MustCompile(tt.want).MatchString(err.Error()) { - t.Errorf("checkLocal returned error\n%q\nwanted it to match\n%s\nregexp: %s", - tt.s.json(), err, tt.want) - } - } -} - -func TestPaths(t *testing.T) { - // CheckStructure should assign paths to schemas. - // This test also verifies that Schema.all visits maps in sorted order. - root := &Schema{ - Type: "string", - PrefixItems: []*Schema{{Type: "int"}, {Items: &Schema{Type: "null"}}}, - Contains: &Schema{Properties: map[string]*Schema{ - "~1": {Type: "boolean"}, - "p": {}, - }}, - } - - type item struct { - s *Schema - p string - } - want := []item{ - {root, "root"}, - {root.Contains, "/contains"}, - {root.Contains.Properties["p"], "/contains/properties/p"}, - {root.Contains.Properties["~1"], "/contains/properties/~01"}, - {root.PrefixItems[0], "/prefixItems/0"}, - {root.PrefixItems[1], "/prefixItems/1"}, - {root.PrefixItems[1].Items, "/prefixItems/1/items"}, - } - rs := newResolved(root) - if err := root.checkStructure(rs.resolvedInfos); err != nil { - t.Fatal(err) - } - - var got []item - for s := range root.all() { - got = append(got, item{s, rs.resolvedInfos[s].path}) - } - if !slices.Equal(got, want) { - t.Errorf("\ngot %v\nwant %v", got, want) - } -} - -func TestResolveURIs(t *testing.T) { - for _, baseURI := range []string{"", "http://a.com"} { - t.Run(baseURI, func(t *testing.T) { - root := &Schema{ - ID: "http://b.com", - Items: &Schema{ - ID: "/foo.json", - }, - Contains: &Schema{ - ID: "/bar.json", - Anchor: "a", - DynamicAnchor: "da", - Items: &Schema{ - Anchor: "b", - Items: &Schema{ - // An ID shouldn't be a query param, but this tests - // resolving an ID with its parent. - ID: "?items", - Anchor: "c", - }, - }, - }, - } - base, err := url.Parse(baseURI) - if err != nil { - t.Fatal(err) - } - - rs := newResolved(root) - if err := root.check(rs.resolvedInfos); err != nil { - t.Fatal(err) - } - if err := resolveURIs(rs, base); err != nil { - t.Fatal(err) - } - - wantIDs := map[string]*Schema{ - baseURI: root, - "http://b.com/foo.json": root.Items, - "http://b.com/bar.json": root.Contains, - "http://b.com/bar.json?items": root.Contains.Items.Items, - } - if baseURI != root.ID { - wantIDs[root.ID] = root - } - wantAnchors := map[*Schema]map[string]anchorInfo{ - root.Contains: { - "a": anchorInfo{root.Contains, false}, - "da": anchorInfo{root.Contains, true}, - "b": anchorInfo{root.Contains.Items, false}, - }, - root.Contains.Items.Items: { - "c": anchorInfo{root.Contains.Items.Items, false}, - }, - } - - got := rs.resolvedURIs - gotKeys := slices.Sorted(maps.Keys(got)) - wantKeys := slices.Sorted(maps.Keys(wantIDs)) - if !slices.Equal(gotKeys, wantKeys) { - t.Errorf("ID keys:\ngot %q\nwant %q", gotKeys, wantKeys) - } - if !maps.Equal(got, wantIDs) { - t.Errorf("IDs:\ngot %+v\n\nwant %+v", got, wantIDs) - } - for s := range root.all() { - info := rs.resolvedInfos[s] - if want := wantAnchors[s]; want != nil { - if got := info.anchors; !maps.Equal(got, want) { - t.Errorf("anchors:\ngot %+v\n\nwant %+v", got, want) - } - } else if info.anchors != nil { - t.Errorf("non-nil anchors for %s", s) - } - } - }) - } -} - -func TestRefCycle(t *testing.T) { - // Verify that cycles of refs are OK. - // The test suite doesn't check this, surprisingly. - schemas := map[string]*Schema{ - "root": {Ref: "a"}, - "a": {Ref: "b"}, - "b": {Ref: "a"}, - } - - loader := func(uri *url.URL) (*Schema, error) { - s, ok := schemas[uri.Path[1:]] - if !ok { - return nil, errors.New("not found") - } - return s, nil - } - - rs, err := schemas["root"].Resolve(&ResolveOptions{Loader: loader}) - if err != nil { - t.Fatal(err) - } - - check := func(s *Schema, key string) { - t.Helper() - if rs.resolvedInfos[s].resolvedRef != schemas[key] { - t.Errorf("%s resolvedRef != schemas[%q]", s.json(), key) - } - } - - check(rs.root, "a") - check(schemas["a"], "b") - check(schemas["b"], "a") -} diff --git a/jsonschema/schema.go b/jsonschema/schema.go deleted file mode 100644 index 9a68cd5d..00000000 --- a/jsonschema/schema.go +++ /dev/null @@ -1,436 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "bytes" - "cmp" - "encoding/json" - "errors" - "fmt" - "iter" - "maps" - "math" - "reflect" - "slices" -) - -// A Schema is a JSON schema object. -// It corresponds to the 2020-12 draft, as described in https://json-schema.org/draft/2020-12, -// specifically: -// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-01 -// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01 -// -// A Schema value may have non-zero values for more than one field: -// all relevant non-zero fields are used for validation. -// There is one exception to provide more Go type-safety: the Type and Types fields -// are mutually exclusive. -// -// Since this struct is a Go representation of a JSON value, it inherits JSON's -// distinction between nil and empty. Nil slices and maps are considered absent, -// but empty ones are present and affect validation. For example, -// -// Schema{Enum: nil} -// -// is equivalent to an empty schema, so it validates every instance. But -// -// Schema{Enum: []any{}} -// -// requires equality to some slice element, so it vacuously rejects every instance. -type Schema struct { - // core - ID string `json:"$id,omitempty"` - Schema string `json:"$schema,omitempty"` - Ref string `json:"$ref,omitempty"` - Comment string `json:"$comment,omitempty"` - Defs map[string]*Schema `json:"$defs,omitempty"` - // definitions is deprecated but still allowed. It is a synonym for $defs. - Definitions map[string]*Schema `json:"definitions,omitempty"` - - Anchor string `json:"$anchor,omitempty"` - DynamicAnchor string `json:"$dynamicAnchor,omitempty"` - DynamicRef string `json:"$dynamicRef,omitempty"` - Vocabulary map[string]bool `json:"$vocabulary,omitempty"` - - // metadata - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - Default json.RawMessage `json:"default,omitempty"` - Deprecated bool `json:"deprecated,omitempty"` - ReadOnly bool `json:"readOnly,omitempty"` - WriteOnly bool `json:"writeOnly,omitempty"` - Examples []any `json:"examples,omitempty"` - - // validation - // Use Type for a single type, or Types for multiple types; never both. - Type string `json:"-"` - Types []string `json:"-"` - Enum []any `json:"enum,omitempty"` - // Const is *any because a JSON null (Go nil) is a valid value. - Const *any `json:"const,omitempty"` - MultipleOf *float64 `json:"multipleOf,omitempty"` - Minimum *float64 `json:"minimum,omitempty"` - Maximum *float64 `json:"maximum,omitempty"` - ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` - ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` - MinLength *int `json:"minLength,omitempty"` - MaxLength *int `json:"maxLength,omitempty"` - Pattern string `json:"pattern,omitempty"` - - // arrays - PrefixItems []*Schema `json:"prefixItems,omitempty"` - Items *Schema `json:"items,omitempty"` - MinItems *int `json:"minItems,omitempty"` - MaxItems *int `json:"maxItems,omitempty"` - AdditionalItems *Schema `json:"additionalItems,omitempty"` - UniqueItems bool `json:"uniqueItems,omitempty"` - Contains *Schema `json:"contains,omitempty"` - MinContains *int `json:"minContains,omitempty"` // *int, not int: default is 1, not 0 - MaxContains *int `json:"maxContains,omitempty"` - UnevaluatedItems *Schema `json:"unevaluatedItems,omitempty"` - - // objects - MinProperties *int `json:"minProperties,omitempty"` - MaxProperties *int `json:"maxProperties,omitempty"` - Required []string `json:"required,omitempty"` - DependentRequired map[string][]string `json:"dependentRequired,omitempty"` - Properties map[string]*Schema `json:"properties,omitempty"` - PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` - AdditionalProperties *Schema `json:"additionalProperties,omitempty"` - PropertyNames *Schema `json:"propertyNames,omitempty"` - UnevaluatedProperties *Schema `json:"unevaluatedProperties,omitempty"` - - // logic - AllOf []*Schema `json:"allOf,omitempty"` - AnyOf []*Schema `json:"anyOf,omitempty"` - OneOf []*Schema `json:"oneOf,omitempty"` - Not *Schema `json:"not,omitempty"` - - // conditional - If *Schema `json:"if,omitempty"` - Then *Schema `json:"then,omitempty"` - Else *Schema `json:"else,omitempty"` - DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` - - // other - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 - ContentEncoding string `json:"contentEncoding,omitempty"` - ContentMediaType string `json:"contentMediaType,omitempty"` - ContentSchema *Schema `json:"contentSchema,omitempty"` - - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 - Format string `json:"format,omitempty"` - - // Extra allows for additional keywords beyond those specified. - Extra map[string]any `json:"-"` -} - -// falseSchema returns a new Schema tree that fails to validate any value. -func falseSchema() *Schema { - return &Schema{Not: &Schema{}} -} - -// anchorInfo records the subschema to which an anchor refers, and whether -// the anchor keyword is $anchor or $dynamicAnchor. -type anchorInfo struct { - schema *Schema - dynamic bool -} - -// String returns a short description of the schema. -func (s *Schema) String() string { - if s.ID != "" { - return s.ID - } - if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { - return fmt.Sprintf("anchor %s", a) - } - return "" -} - -// CloneSchemas returns a copy of s. -// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas. -// This allows both s and s.CloneSchemas() to appear as sub-schemas in the same parent. -func (s *Schema) CloneSchemas() *Schema { - if s == nil { - return nil - } - s2 := *s - v := reflect.ValueOf(&s2) - for _, info := range schemaFieldInfos { - fv := v.Elem().FieldByIndex(info.sf.Index) - switch info.sf.Type { - case schemaType: - sscss := fv.Interface().(*Schema) - fv.Set(reflect.ValueOf(sscss.CloneSchemas())) - - case schemaSliceType: - slice := fv.Interface().([]*Schema) - slice = slices.Clone(slice) - for i, ss := range slice { - slice[i] = ss.CloneSchemas() - } - fv.Set(reflect.ValueOf(slice)) - - case schemaMapType: - m := fv.Interface().(map[string]*Schema) - m = maps.Clone(m) - for k, ss := range m { - m[k] = ss.CloneSchemas() - } - fv.Set(reflect.ValueOf(m)) - } - } - return &s2 -} - -func (s *Schema) basicChecks() error { - if s.Type != "" && s.Types != nil { - return errors.New("both Type and Types are set; at most one should be") - } - if s.Defs != nil && s.Definitions != nil { - return errors.New("both Defs and Definitions are set; at most one should be") - } - return nil -} - -type schemaWithoutMethods Schema // doesn't implement json.{Unm,M}arshaler - -func (s *Schema) MarshalJSON() ([]byte, error) { - if err := s.basicChecks(); err != nil { - return nil, err - } - - // Marshal either Type or Types as "type". - var typ any - switch { - case s.Type != "": - typ = s.Type - case s.Types != nil: - typ = s.Types - } - ms := struct { - Type any `json:"type,omitempty"` - *schemaWithoutMethods - }{ - Type: typ, - schemaWithoutMethods: (*schemaWithoutMethods)(s), - } - bs, err := marshalStructWithMap(&ms, "Extra") - if err != nil { - return nil, err - } - // Marshal {} as true and {"not": {}} as false. - // It is wasteful to do this here instead of earlier, but much easier. - switch { - case bytes.Equal(bs, []byte(`{}`)): - bs = []byte("true") - case bytes.Equal(bs, []byte(`{"not":true}`)): - bs = []byte("false") - } - return bs, nil -} - -func (s *Schema) UnmarshalJSON(data []byte) error { - // A JSON boolean is a valid schema. - var b bool - if err := json.Unmarshal(data, &b); err == nil { - if b { - // true is the empty schema, which validates everything. - *s = Schema{} - } else { - // false is the schema that validates nothing. - *s = *falseSchema() - } - return nil - } - - ms := struct { - Type json.RawMessage `json:"type,omitempty"` - Const json.RawMessage `json:"const,omitempty"` - MinLength *integer `json:"minLength,omitempty"` - MaxLength *integer `json:"maxLength,omitempty"` - MinItems *integer `json:"minItems,omitempty"` - MaxItems *integer `json:"maxItems,omitempty"` - MinProperties *integer `json:"minProperties,omitempty"` - MaxProperties *integer `json:"maxProperties,omitempty"` - MinContains *integer `json:"minContains,omitempty"` - MaxContains *integer `json:"maxContains,omitempty"` - - *schemaWithoutMethods - }{ - schemaWithoutMethods: (*schemaWithoutMethods)(s), - } - if err := unmarshalStructWithMap(data, &ms, "Extra"); err != nil { - return err - } - // Unmarshal "type" as either Type or Types. - var err error - if len(ms.Type) > 0 { - switch ms.Type[0] { - case '"': - err = json.Unmarshal(ms.Type, &s.Type) - case '[': - err = json.Unmarshal(ms.Type, &s.Types) - default: - err = fmt.Errorf(`invalid value for "type": %q`, ms.Type) - } - } - if err != nil { - return err - } - - unmarshalAnyPtr := func(p **any, raw json.RawMessage) error { - if len(raw) == 0 { - return nil - } - if bytes.Equal(raw, []byte("null")) { - *p = new(any) - return nil - } - return json.Unmarshal(raw, p) - } - - // Setting Const to a pointer to null will marshal properly, but won't - // unmarshal: the *any is set to nil, not a pointer to nil. - if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { - return err - } - - set := func(dst **int, src *integer) { - if src != nil { - *dst = Ptr(int(*src)) - } - } - - set(&s.MinLength, ms.MinLength) - set(&s.MaxLength, ms.MaxLength) - set(&s.MinItems, ms.MinItems) - set(&s.MaxItems, ms.MaxItems) - set(&s.MinProperties, ms.MinProperties) - set(&s.MaxProperties, ms.MaxProperties) - set(&s.MinContains, ms.MinContains) - set(&s.MaxContains, ms.MaxContains) - - return nil -} - -type integer int32 // for the integer-valued fields of Schema - -func (ip *integer) UnmarshalJSON(data []byte) error { - if len(data) == 0 { - // nothing to do - return nil - } - // If there is a decimal point, src is a floating-point number. - var i int64 - if bytes.ContainsRune(data, '.') { - var f float64 - if err := json.Unmarshal(data, &f); err != nil { - return errors.New("not a number") - } - i = int64(f) - if float64(i) != f { - return errors.New("not an integer value") - } - } else { - if err := json.Unmarshal(data, &i); err != nil { - return errors.New("cannot be unmarshaled into an int") - } - } - // Ensure behavior is the same on both 32-bit and 64-bit systems. - if i < math.MinInt32 || i > math.MaxInt32 { - return errors.New("integer is out of range") - } - *ip = integer(i) - return nil -} - -// Ptr returns a pointer to a new variable whose value is x. -func Ptr[T any](x T) *T { return &x } - -// every applies f preorder to every schema under s including s. -// The second argument to f is the path to the schema appended to the argument path. -// It stops when f returns false. -func (s *Schema) every(f func(*Schema) bool) bool { - return f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) -} - -// everyChild reports whether f is true for every immediate child schema of s. -func (s *Schema) everyChild(f func(*Schema) bool) bool { - v := reflect.ValueOf(s) - for _, info := range schemaFieldInfos { - fv := v.Elem().FieldByIndex(info.sf.Index) - switch info.sf.Type { - case schemaType: - // A field that contains an individual schema. A nil is valid: it just means the field isn't present. - c := fv.Interface().(*Schema) - if c != nil && !f(c) { - return false - } - - case schemaSliceType: - slice := fv.Interface().([]*Schema) - for _, c := range slice { - if !f(c) { - return false - } - } - - case schemaMapType: - // Sort keys for determinism. - m := fv.Interface().(map[string]*Schema) - for _, k := range slices.Sorted(maps.Keys(m)) { - if !f(m[k]) { - return false - } - } - } - } - return true -} - -// all wraps every in an iterator. -func (s *Schema) all() iter.Seq[*Schema] { - return func(yield func(*Schema) bool) { s.every(yield) } -} - -// children wraps everyChild in an iterator. -func (s *Schema) children() iter.Seq[*Schema] { - return func(yield func(*Schema) bool) { s.everyChild(yield) } -} - -var ( - schemaType = reflect.TypeFor[*Schema]() - schemaSliceType = reflect.TypeFor[[]*Schema]() - schemaMapType = reflect.TypeFor[map[string]*Schema]() -) - -type structFieldInfo struct { - sf reflect.StructField - jsonName string -} - -var ( - // the visible fields of Schema that have a JSON name, sorted by that name - schemaFieldInfos []structFieldInfo - // map from JSON name to field - schemaFieldMap = map[string]reflect.StructField{} -) - -func init() { - for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { - info := fieldJSONInfo(sf) - if !info.omit { - schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.name}) - } - } - slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { - return cmp.Compare(i1.jsonName, i2.jsonName) - }) - for _, info := range schemaFieldInfos { - schemaFieldMap[info.jsonName] = info.sf - } -} diff --git a/jsonschema/schema_test.go b/jsonschema/schema_test.go deleted file mode 100644 index 19f6c6c7..00000000 --- a/jsonschema/schema_test.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "fmt" - "math" - "regexp" - "testing" -) - -func TestGoRoundTrip(t *testing.T) { - // Verify that Go representations round-trip. - for _, s := range []*Schema{ - {Type: "null"}, - {Types: []string{"null", "number"}}, - {Type: "string", MinLength: Ptr(20)}, - {Minimum: Ptr(20.0)}, - {Items: &Schema{Type: "integer"}}, - {Const: Ptr(any(0))}, - {Const: Ptr(any(nil))}, - {Const: Ptr(any([]int{}))}, - {Const: Ptr(any(map[string]any{}))}, - {Default: mustMarshal(1)}, - {Default: mustMarshal(nil)}, - {Extra: map[string]any{"test": "value"}}, - } { - data, err := json.Marshal(s) - if err != nil { - t.Fatal(err) - } - var got *Schema - mustUnmarshal(t, data, &got) - if !Equal(got, s) { - t.Errorf("got %s, want %s", got.json(), s.json()) - if got.Const != nil && s.Const != nil { - t.Logf("Consts: got %#v (%[1]T), want %#v (%[2]T)", *got.Const, *s.Const) - } - } - } -} - -func TestJSONRoundTrip(t *testing.T) { - // Verify that JSON texts for schemas marshal into equivalent forms. - // We don't expect everything to round-trip perfectly. For example, "true" and "false" - // will turn into their object equivalents. - // But most things should. - // Some of these cases test Schema.{UnM,M}arshalJSON. - // Most of others follow from the behavior of encoding/json, but they are still - // valuable as regression tests of this package's behavior. - for _, tt := range []struct { - in, want string - }{ - {`true`, `true`}, - {`false`, `false`}, - {`{"type":"", "enum":null}`, `true`}, // empty fields are omitted - {`{"minimum":1}`, `{"minimum":1}`}, - {`{"minimum":1.0}`, `{"minimum":1}`}, // floating-point integers lose their fractional part - {`{"minLength":1.0}`, `{"minLength":1}`}, // some floats are unmarshaled into ints, but you can't tell - { - // map keys are sorted - `{"$vocabulary":{"b":true, "a":false}}`, - `{"$vocabulary":{"a":false,"b":true}}`, - }, - {`{"unk":0}`, `{"unk":0}`}, // unknown fields are not dropped - { - // known and unknown fields are not dropped - // note that the order will be by the declaration order in the anonymous struct inside MarshalJSON - `{"comment":"test","type":"example","unk":0}`, - `{"type":"example","comment":"test","unk":0}`, - }, - {`{"extra":0}`, `{"extra":0}`}, // extra is not a special keyword and should not be dropped - {`{"Extra":0}`, `{"Extra":0}`}, // Extra is not a special keyword and should not be dropped - } { - var s Schema - mustUnmarshal(t, []byte(tt.in), &s) - data, err := json.Marshal(&s) - if err != nil { - t.Fatal(err) - } - if got := string(data); got != tt.want { - t.Errorf("%s:\ngot %s\nwant %s", tt.in, got, tt.want) - } - } -} - -func TestUnmarshalErrors(t *testing.T) { - for _, tt := range []struct { - in string - want string // error must match this regexp - }{ - {`1`, "cannot unmarshal number"}, - {`{"type":1}`, `invalid value for "type"`}, - {`{"minLength":1.5}`, `not an integer value`}, - {`{"maxLength":1.5}`, `not an integer value`}, - {`{"minItems":1.5}`, `not an integer value`}, - {`{"maxItems":1.5}`, `not an integer value`}, - {`{"minProperties":1.5}`, `not an integer value`}, - {`{"maxProperties":1.5}`, `not an integer value`}, - {`{"minContains":1.5}`, `not an integer value`}, - {`{"maxContains":1.5}`, `not an integer value`}, - {fmt.Sprintf(`{"maxContains":%d}`, int64(math.MaxInt32+1)), `out of range`}, - {`{"minLength":9e99}`, `cannot be unmarshaled`}, - {`{"minLength":"1.5"}`, `not a number`}, - } { - var s Schema - err := json.Unmarshal([]byte(tt.in), &s) - if err == nil { - t.Fatalf("%s: no error but expected one", tt.in) - } - if !regexp.MustCompile(tt.want).MatchString(err.Error()) { - t.Errorf("%s: error %q does not match %q", tt.in, err, tt.want) - } - - } -} - -func mustUnmarshal(t *testing.T, data []byte, ptr any) { - t.Helper() - if err := json.Unmarshal(data, ptr); err != nil { - t.Fatal(err) - } -} - -// json returns the schema in json format. -func (s *Schema) json() string { - data, err := json.Marshal(s) - if err != nil { - return fmt.Sprintf("", err) - } - return string(data) -} - -// json returns the schema in json format, indented. -func (s *Schema) jsonIndent() string { - data, err := json.MarshalIndent(s, "", " ") - if err != nil { - return fmt.Sprintf("", err) - } - return string(data) -} - -func TestCloneSchemas(t *testing.T) { - ss1 := &Schema{Type: "string"} - ss2 := &Schema{Type: "integer"} - ss3 := &Schema{Type: "boolean"} - ss4 := &Schema{Type: "number"} - ss5 := &Schema{Contains: ss4} - - s1 := Schema{ - Contains: ss1, - PrefixItems: []*Schema{ss2, ss3}, - Properties: map[string]*Schema{"a": ss5}, - } - s2 := s1.CloneSchemas() - - // The clones should appear identical. - if g, w := s1.json(), s2.json(); g != w { - t.Errorf("\ngot %s\nwant %s", g, w) - } - // None of the schemas should overlap. - schemas1 := map[*Schema]bool{ss1: true, ss2: true, ss3: true, ss4: true, ss5: true} - for ss := range s2.all() { - if schemas1[ss] { - t.Errorf("uncloned schema %s", ss.json()) - } - } - // s1's original schemas should be intact. - if s1.Contains != ss1 || s1.PrefixItems[0] != ss2 || s1.PrefixItems[1] != ss3 || ss5.Contains != ss4 || s1.Properties["a"] != ss5 { - t.Errorf("s1 modified") - } -} diff --git a/jsonschema/testdata/draft2020-12/README.md b/jsonschema/testdata/draft2020-12/README.md deleted file mode 100644 index dbc397dd..00000000 --- a/jsonschema/testdata/draft2020-12/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# JSON Schema test suite for 2020-12 - -These files were copied from -https://github.com/json-schema-org/JSON-Schema-Test-Suite/tree/83e866b46c9f9e7082fd51e83a61c5f2145a1ab7/tests/draft2020-12. - -The following files were omitted: - -content.json: it is not required to validate content fields -(https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8.1). - -format.json: it is not required to validate format fields (https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7.1). - -vocabulary.json: this package doesn't support explicit vocabularies, other than the 2020-12 draft. - -The "optional" directory: this package doesn't implement any optional features. diff --git a/jsonschema/testdata/draft2020-12/additionalProperties.json b/jsonschema/testdata/draft2020-12/additionalProperties.json deleted file mode 100644 index 9618575e..00000000 --- a/jsonschema/testdata/draft2020-12/additionalProperties.json +++ /dev/null @@ -1,219 +0,0 @@ -[ - { - "description": - "additionalProperties being false does not allow other properties", - "specification": [ { "core":"10.3.2.3", "quote": "The value of \"additionalProperties\" MUST be a valid JSON Schema. Boolean \"false\" forbids everything." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo": {}, "bar": {}}, - "patternProperties": { "^v": {} }, - "additionalProperties": false - }, - "tests": [ - { - "description": "no additional properties is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "an additional property is invalid", - "data": {"foo" : 1, "bar" : 2, "quux" : "boom"}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [1, 2, 3], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobarbaz", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - }, - { - "description": "patternProperties are not additional properties", - "data": {"foo":1, "vroom": 2}, - "valid": true - } - ] - }, - { - "description": "non-ASCII pattern with additionalProperties", - "specification": [ { "core":"10.3.2.3"} ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": {"^á": {}}, - "additionalProperties": false - }, - "tests": [ - { - "description": "matching the pattern is valid", - "data": {"ármányos": 2}, - "valid": true - }, - { - "description": "not matching the pattern is invalid", - "data": {"élmény": 2}, - "valid": false - } - ] - }, - { - "description": "additionalProperties with schema", - "specification": [ { "core":"10.3.2.3", "quote": "The value of \"additionalProperties\" MUST be a valid JSON Schema." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo": {}, "bar": {}}, - "additionalProperties": {"type": "boolean"} - }, - "tests": [ - { - "description": "no additional properties is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "an additional valid property is valid", - "data": {"foo" : 1, "bar" : 2, "quux" : true}, - "valid": true - }, - { - "description": "an additional invalid property is invalid", - "data": {"foo" : 1, "bar" : 2, "quux" : 12}, - "valid": false - } - ] - }, - { - "description": "additionalProperties can exist by itself", - "specification": [ { "core":"10.3.2.3", "quote": "With no other applicator applying to object instances. This validates all the instance values irrespective of their property names" } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "additionalProperties": {"type": "boolean"} - }, - "tests": [ - { - "description": "an additional valid property is valid", - "data": {"foo" : true}, - "valid": true - }, - { - "description": "an additional invalid property is invalid", - "data": {"foo" : 1}, - "valid": false - } - ] - }, - { - "description": "additionalProperties are allowed by default", - "specification": [ { "core":"10.3.2.3", "quote": "Omitting this keyword has the same assertion behavior as an empty schema." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo": {}, "bar": {}} - }, - "tests": [ - { - "description": "additional properties are allowed", - "data": {"foo": 1, "bar": 2, "quux": true}, - "valid": true - } - ] - }, - { - "description": "additionalProperties does not look in applicators", - "specification":[ { "core": "10.2", "quote": "Subschemas of applicator keywords evaluate the instance completely independently such that the results of one such subschema MUST NOT impact the results of sibling subschemas." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {"properties": {"foo": {}}} - ], - "additionalProperties": {"type": "boolean"} - }, - "tests": [ - { - "description": "properties defined in allOf are not examined", - "data": {"foo": 1, "bar": true}, - "valid": false - } - ] - }, - { - "description": "additionalProperties with null valued instance properties", - "specification": [ { "core":"10.3.2.3" } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "additionalProperties": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null values", - "data": {"foo": null}, - "valid": true - } - ] - }, - { - "description": "additionalProperties with propertyNames", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": { - "maxLength": 5 - }, - "additionalProperties": { - "type": "number" - } - }, - "tests": [ - { - "description": "Valid against both keywords", - "data": { "apple": 4 }, - "valid": true - }, - { - "description": "Valid against propertyNames, but not additionalProperties", - "data": { "fig": 2, "pear": "available" }, - "valid": false - } - ] - }, - { - "description": "dependentSchemas with additionalProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo2": {}}, - "dependentSchemas": { - "foo" : {}, - "foo2": { - "properties": { - "bar": {} - } - } - }, - "additionalProperties": false - }, - "tests": [ - { - "description": "additionalProperties doesn't consider dependentSchemas", - "data": {"foo": ""}, - "valid": false - }, - { - "description": "additionalProperties can't see bar", - "data": {"bar": ""}, - "valid": false - }, - { - "description": "additionalProperties can't see bar even when foo2 is present", - "data": {"foo2": "", "bar": ""}, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/allOf.json b/jsonschema/testdata/draft2020-12/allOf.json deleted file mode 100644 index 9e87903f..00000000 --- a/jsonschema/testdata/draft2020-12/allOf.json +++ /dev/null @@ -1,312 +0,0 @@ -[ - { - "description": "allOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "properties": { - "bar": {"type": "integer"} - }, - "required": ["bar"] - }, - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "allOf", - "data": {"foo": "baz", "bar": 2}, - "valid": true - }, - { - "description": "mismatch second", - "data": {"foo": "baz"}, - "valid": false - }, - { - "description": "mismatch first", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "wrong type", - "data": {"foo": "baz", "bar": "quux"}, - "valid": false - } - ] - }, - { - "description": "allOf with base schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"bar": {"type": "integer"}}, - "required": ["bar"], - "allOf" : [ - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - }, - { - "properties": { - "baz": {"type": "null"} - }, - "required": ["baz"] - } - ] - }, - "tests": [ - { - "description": "valid", - "data": {"foo": "quux", "bar": 2, "baz": null}, - "valid": true - }, - { - "description": "mismatch base schema", - "data": {"foo": "quux", "baz": null}, - "valid": false - }, - { - "description": "mismatch first allOf", - "data": {"bar": 2, "baz": null}, - "valid": false - }, - { - "description": "mismatch second allOf", - "data": {"foo": "quux", "bar": 2}, - "valid": false - }, - { - "description": "mismatch both", - "data": {"bar": 2}, - "valid": false - } - ] - }, - { - "description": "allOf simple types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {"maximum": 30}, - {"minimum": 20} - ] - }, - "tests": [ - { - "description": "valid", - "data": 25, - "valid": true - }, - { - "description": "mismatch one", - "data": 35, - "valid": false - } - ] - }, - { - "description": "allOf with boolean schemas, all true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [true, true] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "allOf with boolean schemas, some false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [true, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "allOf with boolean schemas, all false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [false, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "allOf with one empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {} - ] - }, - "tests": [ - { - "description": "any data is valid", - "data": 1, - "valid": true - } - ] - }, - { - "description": "allOf with two empty schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {}, - {} - ] - }, - "tests": [ - { - "description": "any data is valid", - "data": 1, - "valid": true - } - ] - }, - { - "description": "allOf with the first empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {}, - { "type": "number" } - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "allOf with the last empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { "type": "number" }, - {} - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "nested allOf, to check validation semantics", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "allOf": [ - { - "type": "null" - } - ] - } - ] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "anything non-null is invalid", - "data": 123, - "valid": false - } - ] - }, - { - "description": "allOf combined with anyOf, oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ { "multipleOf": 2 } ], - "anyOf": [ { "multipleOf": 3 } ], - "oneOf": [ { "multipleOf": 5 } ] - }, - "tests": [ - { - "description": "allOf: false, anyOf: false, oneOf: false", - "data": 1, - "valid": false - }, - { - "description": "allOf: false, anyOf: false, oneOf: true", - "data": 5, - "valid": false - }, - { - "description": "allOf: false, anyOf: true, oneOf: false", - "data": 3, - "valid": false - }, - { - "description": "allOf: false, anyOf: true, oneOf: true", - "data": 15, - "valid": false - }, - { - "description": "allOf: true, anyOf: false, oneOf: false", - "data": 2, - "valid": false - }, - { - "description": "allOf: true, anyOf: false, oneOf: true", - "data": 10, - "valid": false - }, - { - "description": "allOf: true, anyOf: true, oneOf: false", - "data": 6, - "valid": false - }, - { - "description": "allOf: true, anyOf: true, oneOf: true", - "data": 30, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/anchor.json b/jsonschema/testdata/draft2020-12/anchor.json deleted file mode 100644 index 99143fa1..00000000 --- a/jsonschema/testdata/draft2020-12/anchor.json +++ /dev/null @@ -1,120 +0,0 @@ -[ - { - "description": "Location-independent identifier", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#foo", - "$defs": { - "A": { - "$anchor": "foo", - "type": "integer" - } - } - }, - "tests": [ - { - "data": 1, - "description": "match", - "valid": true - }, - { - "data": "a", - "description": "mismatch", - "valid": false - } - ] - }, - { - "description": "Location-independent identifier with absolute URI", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/bar#foo", - "$defs": { - "A": { - "$id": "http://localhost:1234/draft2020-12/bar", - "$anchor": "foo", - "type": "integer" - } - } - }, - "tests": [ - { - "data": 1, - "description": "match", - "valid": true - }, - { - "data": "a", - "description": "mismatch", - "valid": false - } - ] - }, - { - "description": "Location-independent identifier with base URI change in subschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/root", - "$ref": "http://localhost:1234/draft2020-12/nested.json#foo", - "$defs": { - "A": { - "$id": "nested.json", - "$defs": { - "B": { - "$anchor": "foo", - "type": "integer" - } - } - } - } - }, - "tests": [ - { - "data": 1, - "description": "match", - "valid": true - }, - { - "data": "a", - "description": "mismatch", - "valid": false - } - ] - }, - { - "description": "same $anchor with different base uri", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/foobar", - "$defs": { - "A": { - "$id": "child1", - "allOf": [ - { - "$id": "child2", - "$anchor": "my_anchor", - "type": "number" - }, - { - "$anchor": "my_anchor", - "type": "string" - } - ] - } - }, - "$ref": "child1#my_anchor" - }, - "tests": [ - { - "description": "$ref resolves to /$defs/A/allOf/1", - "data": "a", - "valid": true - }, - { - "description": "$ref does not resolve to /$defs/A/allOf/0", - "data": 1, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/anyOf.json b/jsonschema/testdata/draft2020-12/anyOf.json deleted file mode 100644 index 89b192db..00000000 --- a/jsonschema/testdata/draft2020-12/anyOf.json +++ /dev/null @@ -1,203 +0,0 @@ -[ - { - "description": "anyOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { - "type": "integer" - }, - { - "minimum": 2 - } - ] - }, - "tests": [ - { - "description": "first anyOf valid", - "data": 1, - "valid": true - }, - { - "description": "second anyOf valid", - "data": 2.5, - "valid": true - }, - { - "description": "both anyOf valid", - "data": 3, - "valid": true - }, - { - "description": "neither anyOf valid", - "data": 1.5, - "valid": false - } - ] - }, - { - "description": "anyOf with base schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string", - "anyOf" : [ - { - "maxLength": 2 - }, - { - "minLength": 4 - } - ] - }, - "tests": [ - { - "description": "mismatch base schema", - "data": 3, - "valid": false - }, - { - "description": "one anyOf valid", - "data": "foobar", - "valid": true - }, - { - "description": "both anyOf invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "anyOf with boolean schemas, all true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [true, true] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "anyOf with boolean schemas, some true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [true, false] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "anyOf with boolean schemas, all false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [false, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "anyOf complex types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { - "properties": { - "bar": {"type": "integer"} - }, - "required": ["bar"] - }, - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "first anyOf valid (complex)", - "data": {"bar": 2}, - "valid": true - }, - { - "description": "second anyOf valid (complex)", - "data": {"foo": "baz"}, - "valid": true - }, - { - "description": "both anyOf valid (complex)", - "data": {"foo": "baz", "bar": 2}, - "valid": true - }, - { - "description": "neither anyOf valid (complex)", - "data": {"foo": 2, "bar": "quux"}, - "valid": false - } - ] - }, - { - "description": "anyOf with one empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { "type": "number" }, - {} - ] - }, - "tests": [ - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "number is valid", - "data": 123, - "valid": true - } - ] - }, - { - "description": "nested anyOf, to check validation semantics", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { - "anyOf": [ - { - "type": "null" - } - ] - } - ] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "anything non-null is invalid", - "data": 123, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/boolean_schema.json b/jsonschema/testdata/draft2020-12/boolean_schema.json deleted file mode 100644 index 6d40f23f..00000000 --- a/jsonschema/testdata/draft2020-12/boolean_schema.json +++ /dev/null @@ -1,104 +0,0 @@ -[ - { - "description": "boolean schema 'true'", - "schema": true, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "boolean true is valid", - "data": true, - "valid": true - }, - { - "description": "boolean false is valid", - "data": false, - "valid": true - }, - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - }, - { - "description": "array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "boolean schema 'false'", - "schema": false, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "boolean true is invalid", - "data": true, - "valid": false - }, - { - "description": "boolean false is invalid", - "data": false, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - }, - { - "description": "object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/const.json b/jsonschema/testdata/draft2020-12/const.json deleted file mode 100644 index 50be86a0..00000000 --- a/jsonschema/testdata/draft2020-12/const.json +++ /dev/null @@ -1,387 +0,0 @@ -[ - { - "description": "const validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 2 - }, - "tests": [ - { - "description": "same value is valid", - "data": 2, - "valid": true - }, - { - "description": "another value is invalid", - "data": 5, - "valid": false - }, - { - "description": "another type is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "const with object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": {"foo": "bar", "baz": "bax"} - }, - "tests": [ - { - "description": "same object is valid", - "data": {"foo": "bar", "baz": "bax"}, - "valid": true - }, - { - "description": "same object with different property order is valid", - "data": {"baz": "bax", "foo": "bar"}, - "valid": true - }, - { - "description": "another object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "another type is invalid", - "data": [1, 2], - "valid": false - } - ] - }, - { - "description": "const with array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": [{ "foo": "bar" }] - }, - "tests": [ - { - "description": "same array is valid", - "data": [{"foo": "bar"}], - "valid": true - }, - { - "description": "another array item is invalid", - "data": [2], - "valid": false - }, - { - "description": "array with additional items is invalid", - "data": [1, 2, 3], - "valid": false - } - ] - }, - { - "description": "const with null", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": null - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "not null is invalid", - "data": 0, - "valid": false - } - ] - }, - { - "description": "const with false does not match 0", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": false - }, - "tests": [ - { - "description": "false is valid", - "data": false, - "valid": true - }, - { - "description": "integer zero is invalid", - "data": 0, - "valid": false - }, - { - "description": "float zero is invalid", - "data": 0.0, - "valid": false - } - ] - }, - { - "description": "const with true does not match 1", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": true - }, - "tests": [ - { - "description": "true is valid", - "data": true, - "valid": true - }, - { - "description": "integer one is invalid", - "data": 1, - "valid": false - }, - { - "description": "float one is invalid", - "data": 1.0, - "valid": false - } - ] - }, - { - "description": "const with [false] does not match [0]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": [false] - }, - "tests": [ - { - "description": "[false] is valid", - "data": [false], - "valid": true - }, - { - "description": "[0] is invalid", - "data": [0], - "valid": false - }, - { - "description": "[0.0] is invalid", - "data": [0.0], - "valid": false - } - ] - }, - { - "description": "const with [true] does not match [1]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": [true] - }, - "tests": [ - { - "description": "[true] is valid", - "data": [true], - "valid": true - }, - { - "description": "[1] is invalid", - "data": [1], - "valid": false - }, - { - "description": "[1.0] is invalid", - "data": [1.0], - "valid": false - } - ] - }, - { - "description": "const with {\"a\": false} does not match {\"a\": 0}", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": {"a": false} - }, - "tests": [ - { - "description": "{\"a\": false} is valid", - "data": {"a": false}, - "valid": true - }, - { - "description": "{\"a\": 0} is invalid", - "data": {"a": 0}, - "valid": false - }, - { - "description": "{\"a\": 0.0} is invalid", - "data": {"a": 0.0}, - "valid": false - } - ] - }, - { - "description": "const with {\"a\": true} does not match {\"a\": 1}", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": {"a": true} - }, - "tests": [ - { - "description": "{\"a\": true} is valid", - "data": {"a": true}, - "valid": true - }, - { - "description": "{\"a\": 1} is invalid", - "data": {"a": 1}, - "valid": false - }, - { - "description": "{\"a\": 1.0} is invalid", - "data": {"a": 1.0}, - "valid": false - } - ] - }, - { - "description": "const with 0 does not match other zero-like types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 0 - }, - "tests": [ - { - "description": "false is invalid", - "data": false, - "valid": false - }, - { - "description": "integer zero is valid", - "data": 0, - "valid": true - }, - { - "description": "float zero is valid", - "data": 0.0, - "valid": true - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - }, - { - "description": "empty string is invalid", - "data": "", - "valid": false - } - ] - }, - { - "description": "const with 1 does not match true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 1 - }, - "tests": [ - { - "description": "true is invalid", - "data": true, - "valid": false - }, - { - "description": "integer one is valid", - "data": 1, - "valid": true - }, - { - "description": "float one is valid", - "data": 1.0, - "valid": true - } - ] - }, - { - "description": "const with -2.0 matches integer and float types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": -2.0 - }, - "tests": [ - { - "description": "integer -2 is valid", - "data": -2, - "valid": true - }, - { - "description": "integer 2 is invalid", - "data": 2, - "valid": false - }, - { - "description": "float -2.0 is valid", - "data": -2.0, - "valid": true - }, - { - "description": "float 2.0 is invalid", - "data": 2.0, - "valid": false - }, - { - "description": "float -2.00001 is invalid", - "data": -2.00001, - "valid": false - } - ] - }, - { - "description": "float and integers are equal up to 64-bit representation limits", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 9007199254740992 - }, - "tests": [ - { - "description": "integer is valid", - "data": 9007199254740992, - "valid": true - }, - { - "description": "integer minus one is invalid", - "data": 9007199254740991, - "valid": false - }, - { - "description": "float is valid", - "data": 9007199254740992.0, - "valid": true - }, - { - "description": "float minus one is invalid", - "data": 9007199254740991.0, - "valid": false - } - ] - }, - { - "description": "nul characters in strings", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": "hello\u0000there" - }, - "tests": [ - { - "description": "match string with nul", - "data": "hello\u0000there", - "valid": true - }, - { - "description": "do not match string lacking nul", - "data": "hellothere", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/contains.json b/jsonschema/testdata/draft2020-12/contains.json deleted file mode 100644 index 08a00a75..00000000 --- a/jsonschema/testdata/draft2020-12/contains.json +++ /dev/null @@ -1,176 +0,0 @@ -[ - { - "description": "contains keyword validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"minimum": 5} - }, - "tests": [ - { - "description": "array with item matching schema (5) is valid", - "data": [3, 4, 5], - "valid": true - }, - { - "description": "array with item matching schema (6) is valid", - "data": [3, 4, 6], - "valid": true - }, - { - "description": "array with two items matching schema (5, 6) is valid", - "data": [3, 4, 5, 6], - "valid": true - }, - { - "description": "array without items matching schema is invalid", - "data": [2, 3, 4], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - }, - { - "description": "not array is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "contains keyword with const keyword", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": { "const": 5 } - }, - "tests": [ - { - "description": "array with item 5 is valid", - "data": [3, 4, 5], - "valid": true - }, - { - "description": "array with two items 5 is valid", - "data": [3, 4, 5, 5], - "valid": true - }, - { - "description": "array without item 5 is invalid", - "data": [1, 2, 3, 4], - "valid": false - } - ] - }, - { - "description": "contains keyword with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": true - }, - "tests": [ - { - "description": "any non-empty array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "contains keyword with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": false - }, - "tests": [ - { - "description": "any non-empty array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - }, - { - "description": "non-arrays are valid", - "data": "contains does not apply to strings", - "valid": true - } - ] - }, - { - "description": "items + contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": { "multipleOf": 2 }, - "contains": { "multipleOf": 3 } - }, - "tests": [ - { - "description": "matches items, does not match contains", - "data": [ 2, 4, 8 ], - "valid": false - }, - { - "description": "does not match items, matches contains", - "data": [ 3, 6, 9 ], - "valid": false - }, - { - "description": "matches both items and contains", - "data": [ 6, 12 ], - "valid": true - }, - { - "description": "matches neither items nor contains", - "data": [ 1, 5 ], - "valid": false - } - ] - }, - { - "description": "contains with false if subschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": { - "if": false, - "else": true - } - }, - "tests": [ - { - "description": "any non-empty array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "contains with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null items", - "data": [ null ], - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/default.json b/jsonschema/testdata/draft2020-12/default.json deleted file mode 100644 index ceb3ae27..00000000 --- a/jsonschema/testdata/draft2020-12/default.json +++ /dev/null @@ -1,82 +0,0 @@ -[ - { - "description": "invalid type for default", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": { - "type": "integer", - "default": [] - } - } - }, - "tests": [ - { - "description": "valid when property is specified", - "data": {"foo": 13}, - "valid": true - }, - { - "description": "still valid when the invalid default is used", - "data": {}, - "valid": true - } - ] - }, - { - "description": "invalid string value for default", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "bar": { - "type": "string", - "minLength": 4, - "default": "bad" - } - } - }, - "tests": [ - { - "description": "valid when property is specified", - "data": {"bar": "good"}, - "valid": true - }, - { - "description": "still valid when the invalid default is used", - "data": {}, - "valid": true - } - ] - }, - { - "description": "the default keyword does not do anything if the property is missing", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "alpha": { - "type": "number", - "maximum": 3, - "default": 5 - } - } - }, - "tests": [ - { - "description": "an explicit property value is checked against maximum (passing)", - "data": { "alpha": 1 }, - "valid": true - }, - { - "description": "an explicit property value is checked against maximum (failing)", - "data": { "alpha": 5 }, - "valid": false - }, - { - "description": "missing properties are not filled in with the default", - "data": {}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/defs.json b/jsonschema/testdata/draft2020-12/defs.json deleted file mode 100644 index da2a503b..00000000 --- a/jsonschema/testdata/draft2020-12/defs.json +++ /dev/null @@ -1,21 +0,0 @@ -[ - { - "description": "validate definition against metaschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "https://json-schema.org/draft/2020-12/schema" - }, - "tests": [ - { - "description": "valid definition schema", - "data": {"$defs": {"foo": {"type": "integer"}}}, - "valid": true - }, - { - "description": "invalid definition schema", - "data": {"$defs": {"foo": {"type": 1}}}, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/dependentRequired.json b/jsonschema/testdata/draft2020-12/dependentRequired.json deleted file mode 100644 index 2baa38e9..00000000 --- a/jsonschema/testdata/draft2020-12/dependentRequired.json +++ /dev/null @@ -1,152 +0,0 @@ -[ - { - "description": "single dependency", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": {"bar": ["foo"]} - }, - "tests": [ - { - "description": "neither", - "data": {}, - "valid": true - }, - { - "description": "nondependant", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "with dependency", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "missing dependency", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "ignores arrays", - "data": ["bar"], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "empty dependents", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": {"bar": []} - }, - "tests": [ - { - "description": "empty object", - "data": {}, - "valid": true - }, - { - "description": "object with one property", - "data": {"bar": 2}, - "valid": true - }, - { - "description": "non-object is valid", - "data": 1, - "valid": true - } - ] - }, - { - "description": "multiple dependents required", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": {"quux": ["foo", "bar"]} - }, - "tests": [ - { - "description": "neither", - "data": {}, - "valid": true - }, - { - "description": "nondependants", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "with dependencies", - "data": {"foo": 1, "bar": 2, "quux": 3}, - "valid": true - }, - { - "description": "missing dependency", - "data": {"foo": 1, "quux": 2}, - "valid": false - }, - { - "description": "missing other dependency", - "data": {"bar": 1, "quux": 2}, - "valid": false - }, - { - "description": "missing both dependencies", - "data": {"quux": 1}, - "valid": false - } - ] - }, - { - "description": "dependencies with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": { - "foo\nbar": ["foo\rbar"], - "foo\"bar": ["foo'bar"] - } - }, - "tests": [ - { - "description": "CRLF", - "data": { - "foo\nbar": 1, - "foo\rbar": 2 - }, - "valid": true - }, - { - "description": "quoted quotes", - "data": { - "foo'bar": 1, - "foo\"bar": 2 - }, - "valid": true - }, - { - "description": "CRLF missing dependent", - "data": { - "foo\nbar": 1, - "foo": 2 - }, - "valid": false - }, - { - "description": "quoted quotes missing dependent", - "data": { - "foo\"bar": 2 - }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/dependentSchemas.json b/jsonschema/testdata/draft2020-12/dependentSchemas.json deleted file mode 100644 index 1c5f0574..00000000 --- a/jsonschema/testdata/draft2020-12/dependentSchemas.json +++ /dev/null @@ -1,171 +0,0 @@ -[ - { - "description": "single dependency", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentSchemas": { - "bar": { - "properties": { - "foo": {"type": "integer"}, - "bar": {"type": "integer"} - } - } - } - }, - "tests": [ - { - "description": "valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "no dependency", - "data": {"foo": "quux"}, - "valid": true - }, - { - "description": "wrong type", - "data": {"foo": "quux", "bar": 2}, - "valid": false - }, - { - "description": "wrong type other", - "data": {"foo": 2, "bar": "quux"}, - "valid": false - }, - { - "description": "wrong type both", - "data": {"foo": "quux", "bar": "quux"}, - "valid": false - }, - { - "description": "ignores arrays", - "data": ["bar"], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "boolean subschemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentSchemas": { - "foo": true, - "bar": false - } - }, - "tests": [ - { - "description": "object with property having schema true is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with property having schema false is invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "object with both properties is invalid", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "dependencies with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentSchemas": { - "foo\tbar": {"minProperties": 4}, - "foo'bar": {"required": ["foo\"bar"]} - } - }, - "tests": [ - { - "description": "quoted tab", - "data": { - "foo\tbar": 1, - "a": 2, - "b": 3, - "c": 4 - }, - "valid": true - }, - { - "description": "quoted quote", - "data": { - "foo'bar": {"foo\"bar": 1} - }, - "valid": false - }, - { - "description": "quoted tab invalid under dependent schema", - "data": { - "foo\tbar": 1, - "a": 2 - }, - "valid": false - }, - { - "description": "quoted quote invalid under dependent schema", - "data": {"foo'bar": 1}, - "valid": false - } - ] - }, - { - "description": "dependent subschema incompatible with root", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {} - }, - "dependentSchemas": { - "foo": { - "properties": { - "bar": {} - }, - "additionalProperties": false - } - } - }, - "tests": [ - { - "description": "matches root", - "data": {"foo": 1}, - "valid": false - }, - { - "description": "matches dependency", - "data": {"bar": 1}, - "valid": true - }, - { - "description": "matches both", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "no dependency", - "data": {"baz": 1}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/dynamicRef.json b/jsonschema/testdata/draft2020-12/dynamicRef.json deleted file mode 100644 index ffa211ba..00000000 --- a/jsonschema/testdata/draft2020-12/dynamicRef.json +++ /dev/null @@ -1,815 +0,0 @@ -[ - { - "description": "A $dynamicRef to a $dynamicAnchor in the same schema resource behaves like a normal $ref to an $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamicRef-dynamicAnchor-same-schema/root", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $dynamicRef to an $anchor in the same schema resource behaves like a normal $ref to an $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamicRef-anchor-same-schema/root", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "foo": { - "$anchor": "items", - "type": "string" - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $ref to a $dynamicAnchor in the same schema resource behaves like a normal $ref to an $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/ref-dynamicAnchor-same-schema/root", - "type": "array", - "items": { "$ref": "#items" }, - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $dynamicRef resolves to the first $dynamicAnchor still in scope that is encountered when the schema is evaluated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/typical-dynamic-resolution/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $dynamicRef without anchor in fragment behaves identical to $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamicRef-without-anchor/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#/$defs/items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items", - "type": "number" - } - } - } - } - }, - "tests": [ - { - "description": "An array of strings is invalid", - "data": ["foo", "bar"], - "valid": false - }, - { - "description": "An array of numbers is valid", - "data": [24, 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef with intermediate scopes that don't include a matching $dynamicAnchor does not affect dynamic scope resolution", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-resolution-with-intermediate-scopes/root", - "$ref": "intermediate-scope", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "intermediate-scope": { - "$id": "intermediate-scope", - "$ref": "list" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "An $anchor with the same name as a $dynamicAnchor is not used for dynamic scope resolution", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-resolution-ignores-anchors/root", - "$ref": "list", - "$defs": { - "foo": { - "$anchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "Any array is valid", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef without a matching $dynamicAnchor in the same schema resource behaves like a normal $ref to $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-resolution-without-bookend/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to give the reference somewhere to resolve to when it behaves like $ref", - "$anchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "Any array is valid", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef with a non-matching $dynamicAnchor in the same schema resource behaves like a normal $ref to $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/unmatched-dynamic-anchor/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to give the reference somewhere to resolve to when it behaves like $ref", - "$anchor": "items", - "$dynamicAnchor": "foo" - } - } - } - } - }, - "tests": [ - { - "description": "Any array is valid", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef that initially resolves to a schema with a matching $dynamicAnchor resolves to the first $dynamicAnchor in the dynamic scope", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/relative-dynamic-reference/root", - "$dynamicAnchor": "meta", - "type": "object", - "properties": { - "foo": { "const": "pass" } - }, - "$ref": "extended", - "$defs": { - "extended": { - "$id": "extended", - "$dynamicAnchor": "meta", - "type": "object", - "properties": { - "bar": { "$ref": "bar" } - } - }, - "bar": { - "$id": "bar", - "type": "object", - "properties": { - "baz": { "$dynamicRef": "extended#meta" } - } - } - } - }, - "tests": [ - { - "description": "The recursive part is valid against the root", - "data": { - "foo": "pass", - "bar": { - "baz": { "foo": "pass" } - } - }, - "valid": true - }, - { - "description": "The recursive part is not valid against the root", - "data": { - "foo": "pass", - "bar": { - "baz": { "foo": "fail" } - } - }, - "valid": false - } - ] - }, - { - "description": "A $dynamicRef that initially resolves to a schema without a matching $dynamicAnchor behaves like a normal $ref to $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/relative-dynamic-reference-without-bookend/root", - "$dynamicAnchor": "meta", - "type": "object", - "properties": { - "foo": { "const": "pass" } - }, - "$ref": "extended", - "$defs": { - "extended": { - "$id": "extended", - "$anchor": "meta", - "type": "object", - "properties": { - "bar": { "$ref": "bar" } - } - }, - "bar": { - "$id": "bar", - "type": "object", - "properties": { - "baz": { "$dynamicRef": "extended#meta" } - } - } - } - }, - "tests": [ - { - "description": "The recursive part doesn't need to validate against the root", - "data": { - "foo": "pass", - "bar": { - "baz": { "foo": "fail" } - } - }, - "valid": true - } - ] - }, - { - "description": "multiple dynamic paths to the $dynamicRef keyword", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-ref-with-multiple-paths/main", - "if": { - "properties": { - "kindOfList": { "const": "numbers" } - }, - "required": ["kindOfList"] - }, - "then": { "$ref": "numberList" }, - "else": { "$ref": "stringList" }, - - "$defs": { - "genericList": { - "$id": "genericList", - "properties": { - "list": { - "items": { "$dynamicRef": "#itemType" } - } - }, - "$defs": { - "defaultItemType": { - "$comment": "Only needed to satisfy bookending requirement", - "$dynamicAnchor": "itemType" - } - } - }, - "numberList": { - "$id": "numberList", - "$defs": { - "itemType": { - "$dynamicAnchor": "itemType", - "type": "number" - } - }, - "$ref": "genericList" - }, - "stringList": { - "$id": "stringList", - "$defs": { - "itemType": { - "$dynamicAnchor": "itemType", - "type": "string" - } - }, - "$ref": "genericList" - } - } - }, - "tests": [ - { - "description": "number list with number values", - "data": { - "kindOfList": "numbers", - "list": [1.1] - }, - "valid": true - }, - { - "description": "number list with string values", - "data": { - "kindOfList": "numbers", - "list": ["foo"] - }, - "valid": false - }, - { - "description": "string list with number values", - "data": { - "kindOfList": "strings", - "list": [1.1] - }, - "valid": false - }, - { - "description": "string list with string values", - "data": { - "kindOfList": "strings", - "list": ["foo"] - }, - "valid": true - } - ] - }, - { - "description": "after leaving a dynamic scope, it is not used by a $dynamicRef", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-ref-leaving-dynamic-scope/main", - "if": { - "$id": "first_scope", - "$defs": { - "thingy": { - "$comment": "this is first_scope#thingy", - "$dynamicAnchor": "thingy", - "type": "number" - } - } - }, - "then": { - "$id": "second_scope", - "$ref": "start", - "$defs": { - "thingy": { - "$comment": "this is second_scope#thingy, the final destination of the $dynamicRef", - "$dynamicAnchor": "thingy", - "type": "null" - } - } - }, - "$defs": { - "start": { - "$comment": "this is the landing spot from $ref", - "$id": "start", - "$dynamicRef": "inner_scope#thingy" - }, - "thingy": { - "$comment": "this is the first stop for the $dynamicRef", - "$id": "inner_scope", - "$dynamicAnchor": "thingy", - "type": "string" - } - } - }, - "tests": [ - { - "description": "string matches /$defs/thingy, but the $dynamicRef does not stop here", - "data": "a string", - "valid": false - }, - { - "description": "first_scope is not in dynamic scope for the $dynamicRef", - "data": 42, - "valid": false - }, - { - "description": "/then/$defs/thingy is the final stop for the $dynamicRef", - "data": null, - "valid": true - } - ] - }, - { - "description": "strict-tree schema, guards against misspelled properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-tree.json", - "$dynamicAnchor": "node", - - "$ref": "tree.json", - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "instance with misspelled field", - "data": { - "children": [{ - "daat": 1 - }] - }, - "valid": false - }, - { - "description": "instance with correct field", - "data": { - "children": [{ - "data": 1 - }] - }, - "valid": true - } - ] - }, - { - "description": "tests for implementation dynamic anchor and reference link", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-extendible.json", - "$ref": "extendible-dynamic-ref.json", - "$defs": { - "elements": { - "$dynamicAnchor": "elements", - "properties": { - "a": true - }, - "required": ["a"], - "additionalProperties": false - } - } - }, - "tests": [ - { - "description": "incorrect parent schema", - "data": { - "a": true - }, - "valid": false - }, - { - "description": "incorrect extended schema", - "data": { - "elements": [ - { "b": 1 } - ] - }, - "valid": false - }, - { - "description": "correct extended schema", - "data": { - "elements": [ - { "a": 1 } - ] - }, - "valid": true - } - ] - }, - { - "description": "$ref and $dynamicAnchor are independent of order - $defs first", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-extendible-allof-defs-first.json", - "allOf": [ - { - "$ref": "extendible-dynamic-ref.json" - }, - { - "$defs": { - "elements": { - "$dynamicAnchor": "elements", - "properties": { - "a": true - }, - "required": ["a"], - "additionalProperties": false - } - } - } - ] - }, - "tests": [ - { - "description": "incorrect parent schema", - "data": { - "a": true - }, - "valid": false - }, - { - "description": "incorrect extended schema", - "data": { - "elements": [ - { "b": 1 } - ] - }, - "valid": false - }, - { - "description": "correct extended schema", - "data": { - "elements": [ - { "a": 1 } - ] - }, - "valid": true - } - ] - }, - { - "description": "$ref and $dynamicAnchor are independent of order - $ref first", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-extendible-allof-ref-first.json", - "allOf": [ - { - "$defs": { - "elements": { - "$dynamicAnchor": "elements", - "properties": { - "a": true - }, - "required": ["a"], - "additionalProperties": false - } - } - }, - { - "$ref": "extendible-dynamic-ref.json" - } - ] - }, - "tests": [ - { - "description": "incorrect parent schema", - "data": { - "a": true - }, - "valid": false - }, - { - "description": "incorrect extended schema", - "data": { - "elements": [ - { "b": 1 } - ] - }, - "valid": false - }, - { - "description": "correct extended schema", - "data": { - "elements": [ - { "a": 1 } - ] - }, - "valid": true - } - ] - }, - { - "description": "$ref to $dynamicRef finds detached $dynamicAnchor", - "schema": { - "$ref": "http://localhost:1234/draft2020-12/detached-dynamicref.json#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "$dynamicRef points to a boolean schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "true": true, - "false": false - }, - "properties": { - "true": { - "$dynamicRef": "#/$defs/true" - }, - "false": { - "$dynamicRef": "#/$defs/false" - } - } - }, - "tests": [ - { - "description": "follow $dynamicRef to a true schema", - "data": { "true": 1 }, - "valid": true - }, - { - "description": "follow $dynamicRef to a false schema", - "data": { "false": 1 }, - "valid": false - } - ] - }, - { - "description": "$dynamicRef skips over intermediate resources - direct reference", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-ref-skips-intermediate-resource/main", - "type": "object", - "properties": { - "bar-item": { - "$ref": "item" - } - }, - "$defs": { - "bar": { - "$id": "bar", - "type": "array", - "items": { - "$ref": "item" - }, - "$defs": { - "item": { - "$id": "item", - "type": "object", - "properties": { - "content": { - "$dynamicRef": "#content" - } - }, - "$defs": { - "defaultContent": { - "$dynamicAnchor": "content", - "type": "integer" - } - } - }, - "content": { - "$dynamicAnchor": "content", - "type": "string" - } - } - } - } - }, - "tests": [ - { - "description": "integer property passes", - "data": { "bar-item": { "content": 42 } }, - "valid": true - }, - { - "description": "string property fails", - "data": { "bar-item": { "content": "value" } }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/enum.json b/jsonschema/testdata/draft2020-12/enum.json deleted file mode 100644 index c8f35eac..00000000 --- a/jsonschema/testdata/draft2020-12/enum.json +++ /dev/null @@ -1,358 +0,0 @@ -[ - { - "description": "simple enum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [1, 2, 3] - }, - "tests": [ - { - "description": "one of the enum is valid", - "data": 1, - "valid": true - }, - { - "description": "something else is invalid", - "data": 4, - "valid": false - } - ] - }, - { - "description": "heterogeneous enum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [6, "foo", [], true, {"foo": 12}] - }, - "tests": [ - { - "description": "one of the enum is valid", - "data": [], - "valid": true - }, - { - "description": "something else is invalid", - "data": null, - "valid": false - }, - { - "description": "objects are deep compared", - "data": {"foo": false}, - "valid": false - }, - { - "description": "valid object matches", - "data": {"foo": 12}, - "valid": true - }, - { - "description": "extra properties in object is invalid", - "data": {"foo": 12, "boo": 42}, - "valid": false - } - ] - }, - { - "description": "heterogeneous enum-with-null validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [6, null] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "number is valid", - "data": 6, - "valid": true - }, - { - "description": "something else is invalid", - "data": "test", - "valid": false - } - ] - }, - { - "description": "enums in properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type":"object", - "properties": { - "foo": {"enum":["foo"]}, - "bar": {"enum":["bar"]} - }, - "required": ["bar"] - }, - "tests": [ - { - "description": "both properties are valid", - "data": {"foo":"foo", "bar":"bar"}, - "valid": true - }, - { - "description": "wrong foo value", - "data": {"foo":"foot", "bar":"bar"}, - "valid": false - }, - { - "description": "wrong bar value", - "data": {"foo":"foo", "bar":"bart"}, - "valid": false - }, - { - "description": "missing optional property is valid", - "data": {"bar":"bar"}, - "valid": true - }, - { - "description": "missing required property is invalid", - "data": {"foo":"foo"}, - "valid": false - }, - { - "description": "missing all properties is invalid", - "data": {}, - "valid": false - } - ] - }, - { - "description": "enum with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": ["foo\nbar", "foo\rbar"] - }, - "tests": [ - { - "description": "member 1 is valid", - "data": "foo\nbar", - "valid": true - }, - { - "description": "member 2 is valid", - "data": "foo\rbar", - "valid": true - }, - { - "description": "another string is invalid", - "data": "abc", - "valid": false - } - ] - }, - { - "description": "enum with false does not match 0", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [false] - }, - "tests": [ - { - "description": "false is valid", - "data": false, - "valid": true - }, - { - "description": "integer zero is invalid", - "data": 0, - "valid": false - }, - { - "description": "float zero is invalid", - "data": 0.0, - "valid": false - } - ] - }, - { - "description": "enum with [false] does not match [0]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[false]] - }, - "tests": [ - { - "description": "[false] is valid", - "data": [false], - "valid": true - }, - { - "description": "[0] is invalid", - "data": [0], - "valid": false - }, - { - "description": "[0.0] is invalid", - "data": [0.0], - "valid": false - } - ] - }, - { - "description": "enum with true does not match 1", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [true] - }, - "tests": [ - { - "description": "true is valid", - "data": true, - "valid": true - }, - { - "description": "integer one is invalid", - "data": 1, - "valid": false - }, - { - "description": "float one is invalid", - "data": 1.0, - "valid": false - } - ] - }, - { - "description": "enum with [true] does not match [1]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[true]] - }, - "tests": [ - { - "description": "[true] is valid", - "data": [true], - "valid": true - }, - { - "description": "[1] is invalid", - "data": [1], - "valid": false - }, - { - "description": "[1.0] is invalid", - "data": [1.0], - "valid": false - } - ] - }, - { - "description": "enum with 0 does not match false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [0] - }, - "tests": [ - { - "description": "false is invalid", - "data": false, - "valid": false - }, - { - "description": "integer zero is valid", - "data": 0, - "valid": true - }, - { - "description": "float zero is valid", - "data": 0.0, - "valid": true - } - ] - }, - { - "description": "enum with [0] does not match [false]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[0]] - }, - "tests": [ - { - "description": "[false] is invalid", - "data": [false], - "valid": false - }, - { - "description": "[0] is valid", - "data": [0], - "valid": true - }, - { - "description": "[0.0] is valid", - "data": [0.0], - "valid": true - } - ] - }, - { - "description": "enum with 1 does not match true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [1] - }, - "tests": [ - { - "description": "true is invalid", - "data": true, - "valid": false - }, - { - "description": "integer one is valid", - "data": 1, - "valid": true - }, - { - "description": "float one is valid", - "data": 1.0, - "valid": true - } - ] - }, - { - "description": "enum with [1] does not match [true]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[1]] - }, - "tests": [ - { - "description": "[true] is invalid", - "data": [true], - "valid": false - }, - { - "description": "[1] is valid", - "data": [1], - "valid": true - }, - { - "description": "[1.0] is valid", - "data": [1.0], - "valid": true - } - ] - }, - { - "description": "nul characters in strings", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [ "hello\u0000there" ] - }, - "tests": [ - { - "description": "match string with nul", - "data": "hello\u0000there", - "valid": true - }, - { - "description": "do not match string lacking nul", - "data": "hellothere", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/exclusiveMaximum.json b/jsonschema/testdata/draft2020-12/exclusiveMaximum.json deleted file mode 100644 index 05db2335..00000000 --- a/jsonschema/testdata/draft2020-12/exclusiveMaximum.json +++ /dev/null @@ -1,31 +0,0 @@ -[ - { - "description": "exclusiveMaximum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "exclusiveMaximum": 3.0 - }, - "tests": [ - { - "description": "below the exclusiveMaximum is valid", - "data": 2.2, - "valid": true - }, - { - "description": "boundary point is invalid", - "data": 3.0, - "valid": false - }, - { - "description": "above the exclusiveMaximum is invalid", - "data": 3.5, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/exclusiveMinimum.json b/jsonschema/testdata/draft2020-12/exclusiveMinimum.json deleted file mode 100644 index 00af9d7f..00000000 --- a/jsonschema/testdata/draft2020-12/exclusiveMinimum.json +++ /dev/null @@ -1,31 +0,0 @@ -[ - { - "description": "exclusiveMinimum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "exclusiveMinimum": 1.1 - }, - "tests": [ - { - "description": "above the exclusiveMinimum is valid", - "data": 1.2, - "valid": true - }, - { - "description": "boundary point is invalid", - "data": 1.1, - "valid": false - }, - { - "description": "below the exclusiveMinimum is invalid", - "data": 0.6, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/if-then-else.json b/jsonschema/testdata/draft2020-12/if-then-else.json deleted file mode 100644 index 1c35d7e6..00000000 --- a/jsonschema/testdata/draft2020-12/if-then-else.json +++ /dev/null @@ -1,268 +0,0 @@ -[ - { - "description": "ignore if without then or else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "const": 0 - } - }, - "tests": [ - { - "description": "valid when valid against lone if", - "data": 0, - "valid": true - }, - { - "description": "valid when invalid against lone if", - "data": "hello", - "valid": true - } - ] - }, - { - "description": "ignore then without if", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "then": { - "const": 0 - } - }, - "tests": [ - { - "description": "valid when valid against lone then", - "data": 0, - "valid": true - }, - { - "description": "valid when invalid against lone then", - "data": "hello", - "valid": true - } - ] - }, - { - "description": "ignore else without if", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "else": { - "const": 0 - } - }, - "tests": [ - { - "description": "valid when valid against lone else", - "data": 0, - "valid": true - }, - { - "description": "valid when invalid against lone else", - "data": "hello", - "valid": true - } - ] - }, - { - "description": "if and then without else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "exclusiveMaximum": 0 - }, - "then": { - "minimum": -10 - } - }, - "tests": [ - { - "description": "valid through then", - "data": -1, - "valid": true - }, - { - "description": "invalid through then", - "data": -100, - "valid": false - }, - { - "description": "valid when if test fails", - "data": 3, - "valid": true - } - ] - }, - { - "description": "if and else without then", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "exclusiveMaximum": 0 - }, - "else": { - "multipleOf": 2 - } - }, - "tests": [ - { - "description": "valid when if test passes", - "data": -1, - "valid": true - }, - { - "description": "valid through else", - "data": 4, - "valid": true - }, - { - "description": "invalid through else", - "data": 3, - "valid": false - } - ] - }, - { - "description": "validate against correct branch, then vs else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "exclusiveMaximum": 0 - }, - "then": { - "minimum": -10 - }, - "else": { - "multipleOf": 2 - } - }, - "tests": [ - { - "description": "valid through then", - "data": -1, - "valid": true - }, - { - "description": "invalid through then", - "data": -100, - "valid": false - }, - { - "description": "valid through else", - "data": 4, - "valid": true - }, - { - "description": "invalid through else", - "data": 3, - "valid": false - } - ] - }, - { - "description": "non-interference across combined schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "if": { - "exclusiveMaximum": 0 - } - }, - { - "then": { - "minimum": -10 - } - }, - { - "else": { - "multipleOf": 2 - } - } - ] - }, - "tests": [ - { - "description": "valid, but would have been invalid through then", - "data": -100, - "valid": true - }, - { - "description": "valid, but would have been invalid through else", - "data": 3, - "valid": true - } - ] - }, - { - "description": "if with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": true, - "then": { "const": "then" }, - "else": { "const": "else" } - }, - "tests": [ - { - "description": "boolean schema true in if always chooses the then path (valid)", - "data": "then", - "valid": true - }, - { - "description": "boolean schema true in if always chooses the then path (invalid)", - "data": "else", - "valid": false - } - ] - }, - { - "description": "if with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": false, - "then": { "const": "then" }, - "else": { "const": "else" } - }, - "tests": [ - { - "description": "boolean schema false in if always chooses the else path (invalid)", - "data": "then", - "valid": false - }, - { - "description": "boolean schema false in if always chooses the else path (valid)", - "data": "else", - "valid": true - } - ] - }, - { - "description": "if appears at the end when serialized (keyword processing sequence)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "then": { "const": "yes" }, - "else": { "const": "other" }, - "if": { "maxLength": 4 } - }, - "tests": [ - { - "description": "yes redirects to then and passes", - "data": "yes", - "valid": true - }, - { - "description": "other redirects to else and passes", - "data": "other", - "valid": true - }, - { - "description": "no redirects to then and fails", - "data": "no", - "valid": false - }, - { - "description": "invalid redirects to else and fails", - "data": "invalid", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/infinite-loop-detection.json b/jsonschema/testdata/draft2020-12/infinite-loop-detection.json deleted file mode 100644 index 46f157a3..00000000 --- a/jsonschema/testdata/draft2020-12/infinite-loop-detection.json +++ /dev/null @@ -1,37 +0,0 @@ -[ - { - "description": "evaluating the same schema location against the same data location twice is not a sign of an infinite loop", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "int": { "type": "integer" } - }, - "allOf": [ - { - "properties": { - "foo": { - "$ref": "#/$defs/int" - } - } - }, - { - "additionalProperties": { - "$ref": "#/$defs/int" - } - } - ] - }, - "tests": [ - { - "description": "passing case", - "data": { "foo": 1 }, - "valid": true - }, - { - "description": "failing case", - "data": { "foo": "a string" }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/items.json b/jsonschema/testdata/draft2020-12/items.json deleted file mode 100644 index 6a3e1cf2..00000000 --- a/jsonschema/testdata/draft2020-12/items.json +++ /dev/null @@ -1,304 +0,0 @@ -[ - { - "description": "a schema given for items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": {"type": "integer"} - }, - "tests": [ - { - "description": "valid items", - "data": [ 1, 2, 3 ], - "valid": true - }, - { - "description": "wrong type of items", - "data": [1, "x"], - "valid": false - }, - { - "description": "ignores non-arrays", - "data": {"foo" : "bar"}, - "valid": true - }, - { - "description": "JavaScript pseudo-array is valid", - "data": { - "0": "invalid", - "length": 1 - }, - "valid": true - } - ] - }, - { - "description": "items with boolean schema (true)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": true - }, - "tests": [ - { - "description": "any array is valid", - "data": [ 1, "foo", true ], - "valid": true - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "items with boolean schema (false)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": false - }, - "tests": [ - { - "description": "any non-empty array is invalid", - "data": [ 1, "foo", true ], - "valid": false - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "items and subitems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "item": { - "type": "array", - "items": false, - "prefixItems": [ - { "$ref": "#/$defs/sub-item" }, - { "$ref": "#/$defs/sub-item" } - ] - }, - "sub-item": { - "type": "object", - "required": ["foo"] - } - }, - "type": "array", - "items": false, - "prefixItems": [ - { "$ref": "#/$defs/item" }, - { "$ref": "#/$defs/item" }, - { "$ref": "#/$defs/item" } - ] - }, - "tests": [ - { - "description": "valid items", - "data": [ - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": true - }, - { - "description": "too many items", - "data": [ - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "too many sub-items", - "data": [ - [ {"foo": null}, {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "wrong item", - "data": [ - {"foo": null}, - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "wrong sub-item", - "data": [ - [ {}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "fewer items is valid", - "data": [ - [ {"foo": null} ], - [ {"foo": null} ] - ], - "valid": true - } - ] - }, - { - "description": "nested items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "array", - "items": { - "type": "array", - "items": { - "type": "array", - "items": { - "type": "array", - "items": { - "type": "number" - } - } - } - } - }, - "tests": [ - { - "description": "valid nested array", - "data": [[[[1]], [[2],[3]]], [[[4], [5], [6]]]], - "valid": true - }, - { - "description": "nested array with invalid type", - "data": [[[["1"]], [[2],[3]]], [[[4], [5], [6]]]], - "valid": false - }, - { - "description": "not deep enough", - "data": [[[1], [2],[3]], [[4], [5], [6]]], - "valid": false - } - ] - }, - { - "description": "prefixItems with no additional items allowed", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{}, {}, {}], - "items": false - }, - "tests": [ - { - "description": "empty array", - "data": [ ], - "valid": true - }, - { - "description": "fewer number of items present (1)", - "data": [ 1 ], - "valid": true - }, - { - "description": "fewer number of items present (2)", - "data": [ 1, 2 ], - "valid": true - }, - { - "description": "equal number of items present", - "data": [ 1, 2, 3 ], - "valid": true - }, - { - "description": "additional items are not permitted", - "data": [ 1, 2, 3, 4 ], - "valid": false - } - ] - }, - { - "description": "items does not look in applicators, valid case", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { "prefixItems": [ { "minimum": 3 } ] } - ], - "items": { "minimum": 5 } - }, - "tests": [ - { - "description": "prefixItems in allOf does not constrain items, invalid case", - "data": [ 3, 5 ], - "valid": false - }, - { - "description": "prefixItems in allOf does not constrain items, valid case", - "data": [ 5, 5 ], - "valid": true - } - ] - }, - { - "description": "prefixItems validation adjusts the starting index for items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ { "type": "string" } ], - "items": { "type": "integer" } - }, - "tests": [ - { - "description": "valid items", - "data": [ "x", 2, 3 ], - "valid": true - }, - { - "description": "wrong type of second item", - "data": [ "x", "y" ], - "valid": false - } - ] - }, - { - "description": "items with heterogeneous array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{}], - "items": false - }, - "tests": [ - { - "description": "heterogeneous invalid instance", - "data": [ "foo", "bar", 37 ], - "valid": false - }, - { - "description": "valid instance", - "data": [ null ], - "valid": true - } - ] - }, - { - "description": "items with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null elements", - "data": [ null ], - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxContains.json b/jsonschema/testdata/draft2020-12/maxContains.json deleted file mode 100644 index 8cd3ca74..00000000 --- a/jsonschema/testdata/draft2020-12/maxContains.json +++ /dev/null @@ -1,102 +0,0 @@ -[ - { - "description": "maxContains without contains is ignored", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxContains": 1 - }, - "tests": [ - { - "description": "one item valid against lone maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "two items still valid against lone maxContains", - "data": [ 1, 2 ], - "valid": true - } - ] - }, - { - "description": "maxContains with contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 1 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "all elements match, valid maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "all elements match, invalid maxContains", - "data": [ 1, 1 ], - "valid": false - }, - { - "description": "some elements match, valid maxContains", - "data": [ 1, 2 ], - "valid": true - }, - { - "description": "some elements match, invalid maxContains", - "data": [ 1, 2, 1 ], - "valid": false - } - ] - }, - { - "description": "maxContains with contains, value with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 1.0 - }, - "tests": [ - { - "description": "one element matches, valid maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "too many elements match, invalid maxContains", - "data": [ 1, 1 ], - "valid": false - } - ] - }, - { - "description": "minContains < maxContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 1, - "maxContains": 3 - }, - "tests": [ - { - "description": "actual < minContains < maxContains", - "data": [ ], - "valid": false - }, - { - "description": "minContains < actual < maxContains", - "data": [ 1, 1 ], - "valid": true - }, - { - "description": "minContains < maxContains < actual", - "data": [ 1, 1, 1, 1 ], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxItems.json b/jsonschema/testdata/draft2020-12/maxItems.json deleted file mode 100644 index f6a6b7c9..00000000 --- a/jsonschema/testdata/draft2020-12/maxItems.json +++ /dev/null @@ -1,50 +0,0 @@ -[ - { - "description": "maxItems validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxItems": 2 - }, - "tests": [ - { - "description": "shorter is valid", - "data": [1], - "valid": true - }, - { - "description": "exact length is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "too long is invalid", - "data": [1, 2, 3], - "valid": false - }, - { - "description": "ignores non-arrays", - "data": "foobar", - "valid": true - } - ] - }, - { - "description": "maxItems validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxItems": 2.0 - }, - "tests": [ - { - "description": "shorter is valid", - "data": [1], - "valid": true - }, - { - "description": "too long is invalid", - "data": [1, 2, 3], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxLength.json b/jsonschema/testdata/draft2020-12/maxLength.json deleted file mode 100644 index 7462726d..00000000 --- a/jsonschema/testdata/draft2020-12/maxLength.json +++ /dev/null @@ -1,55 +0,0 @@ -[ - { - "description": "maxLength validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxLength": 2 - }, - "tests": [ - { - "description": "shorter is valid", - "data": "f", - "valid": true - }, - { - "description": "exact length is valid", - "data": "fo", - "valid": true - }, - { - "description": "too long is invalid", - "data": "foo", - "valid": false - }, - { - "description": "ignores non-strings", - "data": 100, - "valid": true - }, - { - "description": "two graphemes is long enough", - "data": "\uD83D\uDCA9\uD83D\uDCA9", - "valid": true - } - ] - }, - { - "description": "maxLength validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxLength": 2.0 - }, - "tests": [ - { - "description": "shorter is valid", - "data": "f", - "valid": true - }, - { - "description": "too long is invalid", - "data": "foo", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxProperties.json b/jsonschema/testdata/draft2020-12/maxProperties.json deleted file mode 100644 index 73ae7316..00000000 --- a/jsonschema/testdata/draft2020-12/maxProperties.json +++ /dev/null @@ -1,79 +0,0 @@ -[ - { - "description": "maxProperties validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxProperties": 2 - }, - "tests": [ - { - "description": "shorter is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "exact length is valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "too long is invalid", - "data": {"foo": 1, "bar": 2, "baz": 3}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [1, 2, 3], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "maxProperties validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxProperties": 2.0 - }, - "tests": [ - { - "description": "shorter is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "too long is invalid", - "data": {"foo": 1, "bar": 2, "baz": 3}, - "valid": false - } - ] - }, - { - "description": "maxProperties = 0 means the object is empty", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxProperties": 0 - }, - "tests": [ - { - "description": "no properties is valid", - "data": {}, - "valid": true - }, - { - "description": "one property is invalid", - "data": { "foo": 1 }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maximum.json b/jsonschema/testdata/draft2020-12/maximum.json deleted file mode 100644 index b99a541e..00000000 --- a/jsonschema/testdata/draft2020-12/maximum.json +++ /dev/null @@ -1,60 +0,0 @@ -[ - { - "description": "maximum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maximum": 3.0 - }, - "tests": [ - { - "description": "below the maximum is valid", - "data": 2.6, - "valid": true - }, - { - "description": "boundary point is valid", - "data": 3.0, - "valid": true - }, - { - "description": "above the maximum is invalid", - "data": 3.5, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - }, - { - "description": "maximum validation with unsigned integer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maximum": 300 - }, - "tests": [ - { - "description": "below the maximum is invalid", - "data": 299.97, - "valid": true - }, - { - "description": "boundary point integer is valid", - "data": 300, - "valid": true - }, - { - "description": "boundary point float is valid", - "data": 300.00, - "valid": true - }, - { - "description": "above the maximum is invalid", - "data": 300.5, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minContains.json b/jsonschema/testdata/draft2020-12/minContains.json deleted file mode 100644 index ee72d7d6..00000000 --- a/jsonschema/testdata/draft2020-12/minContains.json +++ /dev/null @@ -1,224 +0,0 @@ -[ - { - "description": "minContains without contains is ignored", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minContains": 1 - }, - "tests": [ - { - "description": "one item valid against lone minContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "zero items still valid against lone minContains", - "data": [], - "valid": true - } - ] - }, - { - "description": "minContains=1 with contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 1 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "no elements match", - "data": [ 2 ], - "valid": false - }, - { - "description": "single element matches, valid minContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "some elements match, valid minContains", - "data": [ 1, 2 ], - "valid": true - }, - { - "description": "all elements match, valid minContains", - "data": [ 1, 1 ], - "valid": true - } - ] - }, - { - "description": "minContains=2 with contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 2 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "all elements match, invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "some elements match, invalid minContains", - "data": [ 1, 2 ], - "valid": false - }, - { - "description": "all elements match, valid minContains (exactly as needed)", - "data": [ 1, 1 ], - "valid": true - }, - { - "description": "all elements match, valid minContains (more than needed)", - "data": [ 1, 1, 1 ], - "valid": true - }, - { - "description": "some elements match, valid minContains", - "data": [ 1, 2, 1 ], - "valid": true - } - ] - }, - { - "description": "minContains=2 with contains with a decimal value", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 2.0 - }, - "tests": [ - { - "description": "one element matches, invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "both elements match, valid minContains", - "data": [ 1, 1 ], - "valid": true - } - ] - }, - { - "description": "maxContains = minContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 2, - "minContains": 2 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "all elements match, invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "all elements match, invalid maxContains", - "data": [ 1, 1, 1 ], - "valid": false - }, - { - "description": "all elements match, valid maxContains and minContains", - "data": [ 1, 1 ], - "valid": true - } - ] - }, - { - "description": "maxContains < minContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 1, - "minContains": 3 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "invalid maxContains", - "data": [ 1, 1, 1 ], - "valid": false - }, - { - "description": "invalid maxContains and minContains", - "data": [ 1, 1 ], - "valid": false - } - ] - }, - { - "description": "minContains = 0", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 0 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": true - }, - { - "description": "minContains = 0 makes contains always pass", - "data": [ 2 ], - "valid": true - } - ] - }, - { - "description": "minContains = 0 with maxContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 0, - "maxContains": 1 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": true - }, - { - "description": "not more than maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "too many", - "data": [ 1, 1 ], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minItems.json b/jsonschema/testdata/draft2020-12/minItems.json deleted file mode 100644 index 9d6a8b6d..00000000 --- a/jsonschema/testdata/draft2020-12/minItems.json +++ /dev/null @@ -1,50 +0,0 @@ -[ - { - "description": "minItems validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minItems": 1 - }, - "tests": [ - { - "description": "longer is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "exact length is valid", - "data": [1], - "valid": true - }, - { - "description": "too short is invalid", - "data": [], - "valid": false - }, - { - "description": "ignores non-arrays", - "data": "", - "valid": true - } - ] - }, - { - "description": "minItems validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minItems": 1.0 - }, - "tests": [ - { - "description": "longer is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "too short is invalid", - "data": [], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minLength.json b/jsonschema/testdata/draft2020-12/minLength.json deleted file mode 100644 index 5076c5a9..00000000 --- a/jsonschema/testdata/draft2020-12/minLength.json +++ /dev/null @@ -1,55 +0,0 @@ -[ - { - "description": "minLength validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minLength": 2 - }, - "tests": [ - { - "description": "longer is valid", - "data": "foo", - "valid": true - }, - { - "description": "exact length is valid", - "data": "fo", - "valid": true - }, - { - "description": "too short is invalid", - "data": "f", - "valid": false - }, - { - "description": "ignores non-strings", - "data": 1, - "valid": true - }, - { - "description": "one grapheme is not long enough", - "data": "\uD83D\uDCA9", - "valid": false - } - ] - }, - { - "description": "minLength validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minLength": 2.0 - }, - "tests": [ - { - "description": "longer is valid", - "data": "foo", - "valid": true - }, - { - "description": "too short is invalid", - "data": "f", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minProperties.json b/jsonschema/testdata/draft2020-12/minProperties.json deleted file mode 100644 index a753ad35..00000000 --- a/jsonschema/testdata/draft2020-12/minProperties.json +++ /dev/null @@ -1,60 +0,0 @@ -[ - { - "description": "minProperties validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minProperties": 1 - }, - "tests": [ - { - "description": "longer is valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "exact length is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "too short is invalid", - "data": {}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores strings", - "data": "", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "minProperties validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minProperties": 1.0 - }, - "tests": [ - { - "description": "longer is valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "too short is invalid", - "data": {}, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minimum.json b/jsonschema/testdata/draft2020-12/minimum.json deleted file mode 100644 index dc440527..00000000 --- a/jsonschema/testdata/draft2020-12/minimum.json +++ /dev/null @@ -1,75 +0,0 @@ -[ - { - "description": "minimum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minimum": 1.1 - }, - "tests": [ - { - "description": "above the minimum is valid", - "data": 2.6, - "valid": true - }, - { - "description": "boundary point is valid", - "data": 1.1, - "valid": true - }, - { - "description": "below the minimum is invalid", - "data": 0.6, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - }, - { - "description": "minimum validation with signed integer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minimum": -2 - }, - "tests": [ - { - "description": "negative above the minimum is valid", - "data": -1, - "valid": true - }, - { - "description": "positive above the minimum is valid", - "data": 0, - "valid": true - }, - { - "description": "boundary point is valid", - "data": -2, - "valid": true - }, - { - "description": "boundary point with float is valid", - "data": -2.0, - "valid": true - }, - { - "description": "float below the minimum is invalid", - "data": -2.0001, - "valid": false - }, - { - "description": "int below the minimum is invalid", - "data": -3, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/multipleOf.json b/jsonschema/testdata/draft2020-12/multipleOf.json deleted file mode 100644 index 92d6979b..00000000 --- a/jsonschema/testdata/draft2020-12/multipleOf.json +++ /dev/null @@ -1,97 +0,0 @@ -[ - { - "description": "by int", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "multipleOf": 2 - }, - "tests": [ - { - "description": "int by int", - "data": 10, - "valid": true - }, - { - "description": "int by int fail", - "data": 7, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "by number", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "multipleOf": 1.5 - }, - "tests": [ - { - "description": "zero is multiple of anything", - "data": 0, - "valid": true - }, - { - "description": "4.5 is multiple of 1.5", - "data": 4.5, - "valid": true - }, - { - "description": "35 is not multiple of 1.5", - "data": 35, - "valid": false - } - ] - }, - { - "description": "by small number", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "multipleOf": 0.0001 - }, - "tests": [ - { - "description": "0.0075 is multiple of 0.0001", - "data": 0.0075, - "valid": true - }, - { - "description": "0.00751 is not multiple of 0.0001", - "data": 0.00751, - "valid": false - } - ] - }, - { - "description": "float division = inf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer", "multipleOf": 0.123456789 - }, - "tests": [ - { - "description": "always invalid, but naive implementations may raise an overflow error", - "data": 1e308, - "valid": false - } - ] - }, - { - "description": "small multiple of large integer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer", "multipleOf": 1e-8 - }, - "tests": [ - { - "description": "any integer is a multiple of 1e-8", - "data": 12391239123, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/not.json b/jsonschema/testdata/draft2020-12/not.json deleted file mode 100644 index 346d4a7e..00000000 --- a/jsonschema/testdata/draft2020-12/not.json +++ /dev/null @@ -1,301 +0,0 @@ -[ - { - "description": "not", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": {"type": "integer"} - }, - "tests": [ - { - "description": "allowed", - "data": "foo", - "valid": true - }, - { - "description": "disallowed", - "data": 1, - "valid": false - } - ] - }, - { - "description": "not multiple types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": {"type": ["integer", "boolean"]} - }, - "tests": [ - { - "description": "valid", - "data": "foo", - "valid": true - }, - { - "description": "mismatch", - "data": 1, - "valid": false - }, - { - "description": "other mismatch", - "data": true, - "valid": false - } - ] - }, - { - "description": "not more complex schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": { - "type": "object", - "properties": { - "foo": { - "type": "string" - } - } - } - }, - "tests": [ - { - "description": "match", - "data": 1, - "valid": true - }, - { - "description": "other match", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "mismatch", - "data": {"foo": "bar"}, - "valid": false - } - ] - }, - { - "description": "forbidden property", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": { - "not": {} - } - } - }, - "tests": [ - { - "description": "property present", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "property absent", - "data": {"bar": 1, "baz": 2}, - "valid": true - } - ] - }, - { - "description": "forbid everything with empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": {} - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "boolean true is invalid", - "data": true, - "valid": false - }, - { - "description": "boolean false is invalid", - "data": false, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - }, - { - "description": "object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "forbid everything with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": true - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "boolean true is invalid", - "data": true, - "valid": false - }, - { - "description": "boolean false is invalid", - "data": false, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - }, - { - "description": "object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "allow everything with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": false - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "boolean true is valid", - "data": true, - "valid": true - }, - { - "description": "boolean false is valid", - "data": false, - "valid": true - }, - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - }, - { - "description": "array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "double negation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": { "not": {} } - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "collect annotations inside a 'not', even if collection is disabled", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": { - "$comment": "this subschema must still produce annotations internally, even though the 'not' will ultimately discard them", - "anyOf": [ - true, - { "properties": { "foo": true } } - ], - "unevaluatedProperties": false - } - }, - "tests": [ - { - "description": "unevaluated property", - "data": { "bar": 1 }, - "valid": true - }, - { - "description": "annotations are still collected inside a 'not'", - "data": { "foo": 1 }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/oneOf.json b/jsonschema/testdata/draft2020-12/oneOf.json deleted file mode 100644 index 7a7c7ffe..00000000 --- a/jsonschema/testdata/draft2020-12/oneOf.json +++ /dev/null @@ -1,293 +0,0 @@ -[ - { - "description": "oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "type": "integer" - }, - { - "minimum": 2 - } - ] - }, - "tests": [ - { - "description": "first oneOf valid", - "data": 1, - "valid": true - }, - { - "description": "second oneOf valid", - "data": 2.5, - "valid": true - }, - { - "description": "both oneOf valid", - "data": 3, - "valid": false - }, - { - "description": "neither oneOf valid", - "data": 1.5, - "valid": false - } - ] - }, - { - "description": "oneOf with base schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string", - "oneOf" : [ - { - "minLength": 2 - }, - { - "maxLength": 4 - } - ] - }, - "tests": [ - { - "description": "mismatch base schema", - "data": 3, - "valid": false - }, - { - "description": "one oneOf valid", - "data": "foobar", - "valid": true - }, - { - "description": "both oneOf valid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf with boolean schemas, all true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [true, true, true] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf with boolean schemas, one true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [true, false, false] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "oneOf with boolean schemas, more than one true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [true, true, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf with boolean schemas, all false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [false, false, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf complex types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "properties": { - "bar": {"type": "integer"} - }, - "required": ["bar"] - }, - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "first oneOf valid (complex)", - "data": {"bar": 2}, - "valid": true - }, - { - "description": "second oneOf valid (complex)", - "data": {"foo": "baz"}, - "valid": true - }, - { - "description": "both oneOf valid (complex)", - "data": {"foo": "baz", "bar": 2}, - "valid": false - }, - { - "description": "neither oneOf valid (complex)", - "data": {"foo": 2, "bar": "quux"}, - "valid": false - } - ] - }, - { - "description": "oneOf with empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { "type": "number" }, - {} - ] - }, - "tests": [ - { - "description": "one valid - valid", - "data": "foo", - "valid": true - }, - { - "description": "both valid - invalid", - "data": 123, - "valid": false - } - ] - }, - { - "description": "oneOf with required", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "oneOf": [ - { "required": ["foo", "bar"] }, - { "required": ["foo", "baz"] } - ] - }, - "tests": [ - { - "description": "both invalid - invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "first valid - valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "second valid - valid", - "data": {"foo": 1, "baz": 3}, - "valid": true - }, - { - "description": "both valid - invalid", - "data": {"foo": 1, "bar": 2, "baz" : 3}, - "valid": false - } - ] - }, - { - "description": "oneOf with missing optional property", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "properties": { - "bar": true, - "baz": true - }, - "required": ["bar"] - }, - { - "properties": { - "foo": true - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "first oneOf valid", - "data": {"bar": 8}, - "valid": true - }, - { - "description": "second oneOf valid", - "data": {"foo": "foo"}, - "valid": true - }, - { - "description": "both oneOf valid", - "data": {"foo": "foo", "bar": 8}, - "valid": false - }, - { - "description": "neither oneOf valid", - "data": {"baz": "quux"}, - "valid": false - } - ] - }, - { - "description": "nested oneOf, to check validation semantics", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "oneOf": [ - { - "type": "null" - } - ] - } - ] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "anything non-null is invalid", - "data": 123, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/pattern.json b/jsonschema/testdata/draft2020-12/pattern.json deleted file mode 100644 index af0b8d89..00000000 --- a/jsonschema/testdata/draft2020-12/pattern.json +++ /dev/null @@ -1,65 +0,0 @@ -[ - { - "description": "pattern validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "pattern": "^a*$" - }, - "tests": [ - { - "description": "a matching pattern is valid", - "data": "aaa", - "valid": true - }, - { - "description": "a non-matching pattern is invalid", - "data": "abc", - "valid": false - }, - { - "description": "ignores booleans", - "data": true, - "valid": true - }, - { - "description": "ignores integers", - "data": 123, - "valid": true - }, - { - "description": "ignores floats", - "data": 1.0, - "valid": true - }, - { - "description": "ignores objects", - "data": {}, - "valid": true - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores null", - "data": null, - "valid": true - } - ] - }, - { - "description": "pattern is not anchored", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "pattern": "a+" - }, - "tests": [ - { - "description": "matches a substring", - "data": "xxaayy", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/patternProperties.json b/jsonschema/testdata/draft2020-12/patternProperties.json deleted file mode 100644 index 81829c71..00000000 --- a/jsonschema/testdata/draft2020-12/patternProperties.json +++ /dev/null @@ -1,176 +0,0 @@ -[ - { - "description": - "patternProperties validates properties matching a regex", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "f.*o": {"type": "integer"} - } - }, - "tests": [ - { - "description": "a single valid match is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "multiple valid matches is valid", - "data": {"foo": 1, "foooooo" : 2}, - "valid": true - }, - { - "description": "a single invalid match is invalid", - "data": {"foo": "bar", "fooooo": 2}, - "valid": false - }, - { - "description": "multiple invalid matches is invalid", - "data": {"foo": "bar", "foooooo" : "baz"}, - "valid": false - }, - { - "description": "ignores arrays", - "data": ["foo"], - "valid": true - }, - { - "description": "ignores strings", - "data": "foo", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "multiple simultaneous patternProperties are validated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "a*": {"type": "integer"}, - "aaa*": {"maximum": 20} - } - }, - "tests": [ - { - "description": "a single valid match is valid", - "data": {"a": 21}, - "valid": true - }, - { - "description": "a simultaneous match is valid", - "data": {"aaaa": 18}, - "valid": true - }, - { - "description": "multiple matches is valid", - "data": {"a": 21, "aaaa": 18}, - "valid": true - }, - { - "description": "an invalid due to one is invalid", - "data": {"a": "bar"}, - "valid": false - }, - { - "description": "an invalid due to the other is invalid", - "data": {"aaaa": 31}, - "valid": false - }, - { - "description": "an invalid due to both is invalid", - "data": {"aaa": "foo", "aaaa": 31}, - "valid": false - } - ] - }, - { - "description": "regexes are not anchored by default and are case sensitive", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "[0-9]{2,}": { "type": "boolean" }, - "X_": { "type": "string" } - } - }, - "tests": [ - { - "description": "non recognized members are ignored", - "data": { "answer 1": "42" }, - "valid": true - }, - { - "description": "recognized members are accounted for", - "data": { "a31b": null }, - "valid": false - }, - { - "description": "regexes are case sensitive", - "data": { "a_x_3": 3 }, - "valid": true - }, - { - "description": "regexes are case sensitive, 2", - "data": { "a_X_3": 3 }, - "valid": false - } - ] - }, - { - "description": "patternProperties with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "f.*": true, - "b.*": false - } - }, - "tests": [ - { - "description": "object with property matching schema true is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with property matching schema false is invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "object with both properties is invalid", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "object with a property matching both true and false is invalid", - "data": {"foobar":1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "patternProperties with null valued instance properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "^.*bar$": {"type": "null"} - } - }, - "tests": [ - { - "description": "allows null values", - "data": {"foobar": null}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/prefixItems.json b/jsonschema/testdata/draft2020-12/prefixItems.json deleted file mode 100644 index 0adfc069..00000000 --- a/jsonschema/testdata/draft2020-12/prefixItems.json +++ /dev/null @@ -1,104 +0,0 @@ -[ - { - "description": "a schema given for prefixItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - {"type": "integer"}, - {"type": "string"} - ] - }, - "tests": [ - { - "description": "correct types", - "data": [ 1, "foo" ], - "valid": true - }, - { - "description": "wrong types", - "data": [ "foo", 1 ], - "valid": false - }, - { - "description": "incomplete array of items", - "data": [ 1 ], - "valid": true - }, - { - "description": "array with additional items", - "data": [ 1, "foo", true ], - "valid": true - }, - { - "description": "empty array", - "data": [ ], - "valid": true - }, - { - "description": "JavaScript pseudo-array is valid", - "data": { - "0": "invalid", - "1": "valid", - "length": 2 - }, - "valid": true - } - ] - }, - { - "description": "prefixItems with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [true, false] - }, - "tests": [ - { - "description": "array with one item is valid", - "data": [ 1 ], - "valid": true - }, - { - "description": "array with two items is invalid", - "data": [ 1, "foo" ], - "valid": false - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "additional items are allowed by default", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "integer"}] - }, - "tests": [ - { - "description": "only the first item is validated", - "data": [1, "foo", false], - "valid": true - } - ] - }, - { - "description": "prefixItems with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { - "type": "null" - } - ] - }, - "tests": [ - { - "description": "allows null elements", - "data": [ null ], - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/properties.json b/jsonschema/testdata/draft2020-12/properties.json deleted file mode 100644 index 523dcde7..00000000 --- a/jsonschema/testdata/draft2020-12/properties.json +++ /dev/null @@ -1,242 +0,0 @@ -[ - { - "description": "object properties validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "integer"}, - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "both properties present and valid is valid", - "data": {"foo": 1, "bar": "baz"}, - "valid": true - }, - { - "description": "one property invalid is invalid", - "data": {"foo": 1, "bar": {}}, - "valid": false - }, - { - "description": "both properties invalid is invalid", - "data": {"foo": [], "bar": {}}, - "valid": false - }, - { - "description": "doesn't invalidate other properties", - "data": {"quux": []}, - "valid": true - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": - "properties, patternProperties, additionalProperties interaction", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "array", "maxItems": 3}, - "bar": {"type": "array"} - }, - "patternProperties": {"f.o": {"minItems": 2}}, - "additionalProperties": {"type": "integer"} - }, - "tests": [ - { - "description": "property validates property", - "data": {"foo": [1, 2]}, - "valid": true - }, - { - "description": "property invalidates property", - "data": {"foo": [1, 2, 3, 4]}, - "valid": false - }, - { - "description": "patternProperty invalidates property", - "data": {"foo": []}, - "valid": false - }, - { - "description": "patternProperty validates nonproperty", - "data": {"fxo": [1, 2]}, - "valid": true - }, - { - "description": "patternProperty invalidates nonproperty", - "data": {"fxo": []}, - "valid": false - }, - { - "description": "additionalProperty ignores property", - "data": {"bar": []}, - "valid": true - }, - { - "description": "additionalProperty validates others", - "data": {"quux": 3}, - "valid": true - }, - { - "description": "additionalProperty invalidates others", - "data": {"quux": "foo"}, - "valid": false - } - ] - }, - { - "description": "properties with boolean schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": true, - "bar": false - } - }, - "tests": [ - { - "description": "no property present is valid", - "data": {}, - "valid": true - }, - { - "description": "only 'true' property present is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "only 'false' property present is invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "both properties present is invalid", - "data": {"foo": 1, "bar": 2}, - "valid": false - } - ] - }, - { - "description": "properties with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo\nbar": {"type": "number"}, - "foo\"bar": {"type": "number"}, - "foo\\bar": {"type": "number"}, - "foo\rbar": {"type": "number"}, - "foo\tbar": {"type": "number"}, - "foo\fbar": {"type": "number"} - } - }, - "tests": [ - { - "description": "object with all numbers is valid", - "data": { - "foo\nbar": 1, - "foo\"bar": 1, - "foo\\bar": 1, - "foo\rbar": 1, - "foo\tbar": 1, - "foo\fbar": 1 - }, - "valid": true - }, - { - "description": "object with strings is invalid", - "data": { - "foo\nbar": "1", - "foo\"bar": "1", - "foo\\bar": "1", - "foo\rbar": "1", - "foo\tbar": "1", - "foo\fbar": "1" - }, - "valid": false - } - ] - }, - { - "description": "properties with null valued instance properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "null"} - } - }, - "tests": [ - { - "description": "allows null values", - "data": {"foo": null}, - "valid": true - } - ] - }, - { - "description": "properties whose names are Javascript object property names", - "comment": "Ensure JS implementations don't universally consider e.g. __proto__ to always be present in an object.", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "__proto__": {"type": "number"}, - "toString": { - "properties": { "length": { "type": "string" } } - }, - "constructor": {"type": "number"} - } - }, - "tests": [ - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - }, - { - "description": "none of the properties mentioned", - "data": {}, - "valid": true - }, - { - "description": "__proto__ not valid", - "data": { "__proto__": "foo" }, - "valid": false - }, - { - "description": "toString not valid", - "data": { "toString": { "length": 37 } }, - "valid": false - }, - { - "description": "constructor not valid", - "data": { "constructor": { "length": 37 } }, - "valid": false - }, - { - "description": "all present and valid", - "data": { - "__proto__": 12, - "toString": { "length": "foo" }, - "constructor": 37 - }, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/propertyNames.json b/jsonschema/testdata/draft2020-12/propertyNames.json deleted file mode 100644 index b4780088..00000000 --- a/jsonschema/testdata/draft2020-12/propertyNames.json +++ /dev/null @@ -1,168 +0,0 @@ -[ - { - "description": "propertyNames validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"maxLength": 3} - }, - "tests": [ - { - "description": "all property names valid", - "data": { - "f": {}, - "foo": {} - }, - "valid": true - }, - { - "description": "some property names invalid", - "data": { - "foo": {}, - "foobar": {} - }, - "valid": false - }, - { - "description": "object without properties is valid", - "data": {}, - "valid": true - }, - { - "description": "ignores arrays", - "data": [1, 2, 3, 4], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "propertyNames validation with pattern", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": { "pattern": "^a+$" } - }, - "tests": [ - { - "description": "matching property names valid", - "data": { - "a": {}, - "aa": {}, - "aaa": {} - }, - "valid": true - }, - { - "description": "non-matching property name is invalid", - "data": { - "aaA": {} - }, - "valid": false - }, - { - "description": "object without properties is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": true - }, - "tests": [ - { - "description": "object with any properties is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": false - }, - "tests": [ - { - "description": "object with any properties is invalid", - "data": {"foo": 1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with const", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"const": "foo"} - }, - "tests": [ - { - "description": "object with property foo is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with any other property is invalid", - "data": {"bar": 1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with enum", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"enum": ["foo", "bar"]} - }, - "tests": [ - { - "description": "object with property foo is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with property foo and bar is valid", - "data": {"foo": 1, "bar": 1}, - "valid": true - }, - { - "description": "object with any other property is invalid", - "data": {"baz": 1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/ref.json b/jsonschema/testdata/draft2020-12/ref.json deleted file mode 100644 index 0ac02fb9..00000000 --- a/jsonschema/testdata/draft2020-12/ref.json +++ /dev/null @@ -1,1052 +0,0 @@ -[ - { - "description": "root pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"$ref": "#"} - }, - "additionalProperties": false - }, - "tests": [ - { - "description": "match", - "data": {"foo": false}, - "valid": true - }, - { - "description": "recursive match", - "data": {"foo": {"foo": false}}, - "valid": true - }, - { - "description": "mismatch", - "data": {"bar": false}, - "valid": false - }, - { - "description": "recursive mismatch", - "data": {"foo": {"bar": false}}, - "valid": false - } - ] - }, - { - "description": "relative pointer ref to object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "integer"}, - "bar": {"$ref": "#/properties/foo"} - } - }, - "tests": [ - { - "description": "match", - "data": {"bar": 3}, - "valid": true - }, - { - "description": "mismatch", - "data": {"bar": true}, - "valid": false - } - ] - }, - { - "description": "relative pointer ref to array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - {"type": "integer"}, - {"$ref": "#/prefixItems/0"} - ] - }, - "tests": [ - { - "description": "match array", - "data": [1, 2], - "valid": true - }, - { - "description": "mismatch array", - "data": [1, "foo"], - "valid": false - } - ] - }, - { - "description": "escaped pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "tilde~field": {"type": "integer"}, - "slash/field": {"type": "integer"}, - "percent%field": {"type": "integer"} - }, - "properties": { - "tilde": {"$ref": "#/$defs/tilde~0field"}, - "slash": {"$ref": "#/$defs/slash~1field"}, - "percent": {"$ref": "#/$defs/percent%25field"} - } - }, - "tests": [ - { - "description": "slash invalid", - "data": {"slash": "aoeu"}, - "valid": false - }, - { - "description": "tilde invalid", - "data": {"tilde": "aoeu"}, - "valid": false - }, - { - "description": "percent invalid", - "data": {"percent": "aoeu"}, - "valid": false - }, - { - "description": "slash valid", - "data": {"slash": 123}, - "valid": true - }, - { - "description": "tilde valid", - "data": {"tilde": 123}, - "valid": true - }, - { - "description": "percent valid", - "data": {"percent": 123}, - "valid": true - } - ] - }, - { - "description": "nested refs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "a": {"type": "integer"}, - "b": {"$ref": "#/$defs/a"}, - "c": {"$ref": "#/$defs/b"} - }, - "$ref": "#/$defs/c" - }, - "tests": [ - { - "description": "nested ref valid", - "data": 5, - "valid": true - }, - { - "description": "nested ref invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "ref applies alongside sibling keywords", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "reffed": { - "type": "array" - } - }, - "properties": { - "foo": { - "$ref": "#/$defs/reffed", - "maxItems": 2 - } - } - }, - "tests": [ - { - "description": "ref valid, maxItems valid", - "data": { "foo": [] }, - "valid": true - }, - { - "description": "ref valid, maxItems invalid", - "data": { "foo": [1, 2, 3] }, - "valid": false - }, - { - "description": "ref invalid", - "data": { "foo": "string" }, - "valid": false - } - ] - }, - { - "description": "remote ref, containing refs itself", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "https://json-schema.org/draft/2020-12/schema" - }, - "tests": [ - { - "description": "remote ref valid", - "data": {"minLength": 1}, - "valid": true - }, - { - "description": "remote ref invalid", - "data": {"minLength": -1}, - "valid": false - } - ] - }, - { - "description": "property named $ref that is not a reference", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "$ref": {"type": "string"} - } - }, - "tests": [ - { - "description": "property named $ref valid", - "data": {"$ref": "a"}, - "valid": true - }, - { - "description": "property named $ref invalid", - "data": {"$ref": 2}, - "valid": false - } - ] - }, - { - "description": "property named $ref, containing an actual $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "$ref": {"$ref": "#/$defs/is-string"} - }, - "$defs": { - "is-string": { - "type": "string" - } - } - }, - "tests": [ - { - "description": "property named $ref valid", - "data": {"$ref": "a"}, - "valid": true - }, - { - "description": "property named $ref invalid", - "data": {"$ref": 2}, - "valid": false - } - ] - }, - { - "description": "$ref to boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/bool", - "$defs": { - "bool": true - } - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "$ref to boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/bool", - "$defs": { - "bool": false - } - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "Recursive references between schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/tree", - "description": "tree of nodes", - "type": "object", - "properties": { - "meta": {"type": "string"}, - "nodes": { - "type": "array", - "items": {"$ref": "node"} - } - }, - "required": ["meta", "nodes"], - "$defs": { - "node": { - "$id": "http://localhost:1234/draft2020-12/node", - "description": "node", - "type": "object", - "properties": { - "value": {"type": "number"}, - "subtree": {"$ref": "tree"} - }, - "required": ["value"] - } - } - }, - "tests": [ - { - "description": "valid tree", - "data": { - "meta": "root", - "nodes": [ - { - "value": 1, - "subtree": { - "meta": "child", - "nodes": [ - {"value": 1.1}, - {"value": 1.2} - ] - } - }, - { - "value": 2, - "subtree": { - "meta": "child", - "nodes": [ - {"value": 2.1}, - {"value": 2.2} - ] - } - } - ] - }, - "valid": true - }, - { - "description": "invalid tree", - "data": { - "meta": "root", - "nodes": [ - { - "value": 1, - "subtree": { - "meta": "child", - "nodes": [ - {"value": "string is invalid"}, - {"value": 1.2} - ] - } - }, - { - "value": 2, - "subtree": { - "meta": "child", - "nodes": [ - {"value": 2.1}, - {"value": 2.2} - ] - } - } - ] - }, - "valid": false - } - ] - }, - { - "description": "refs with quote", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo\"bar": {"$ref": "#/$defs/foo%22bar"} - }, - "$defs": { - "foo\"bar": {"type": "number"} - } - }, - "tests": [ - { - "description": "object with numbers is valid", - "data": { - "foo\"bar": 1 - }, - "valid": true - }, - { - "description": "object with strings is invalid", - "data": { - "foo\"bar": "1" - }, - "valid": false - } - ] - }, - { - "description": "ref creates new scope when adjacent to keywords", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "A": { - "unevaluatedProperties": false - } - }, - "properties": { - "prop1": { - "type": "string" - } - }, - "$ref": "#/$defs/A" - }, - "tests": [ - { - "description": "referenced subschema doesn't see annotations from properties", - "data": { - "prop1": "match" - }, - "valid": false - } - ] - }, - { - "description": "naive replacement of $ref with its destination is not correct", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "a_string": { "type": "string" } - }, - "enum": [ - { "$ref": "#/$defs/a_string" } - ] - }, - "tests": [ - { - "description": "do not evaluate the $ref inside the enum, matching any string", - "data": "this is a string", - "valid": false - }, - { - "description": "do not evaluate the $ref inside the enum, definition exact match", - "data": { "type": "string" }, - "valid": false - }, - { - "description": "match the enum exactly", - "data": { "$ref": "#/$defs/a_string" }, - "valid": true - } - ] - }, - { - "description": "refs with relative uris and defs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/schema-relative-uri-defs1.json", - "properties": { - "foo": { - "$id": "schema-relative-uri-defs2.json", - "$defs": { - "inner": { - "properties": { - "bar": { "type": "string" } - } - } - }, - "$ref": "#/$defs/inner" - } - }, - "$ref": "schema-relative-uri-defs2.json" - }, - "tests": [ - { - "description": "invalid on inner field", - "data": { - "foo": { - "bar": 1 - }, - "bar": "a" - }, - "valid": false - }, - { - "description": "invalid on outer field", - "data": { - "foo": { - "bar": "a" - }, - "bar": 1 - }, - "valid": false - }, - { - "description": "valid on both fields", - "data": { - "foo": { - "bar": "a" - }, - "bar": "a" - }, - "valid": true - } - ] - }, - { - "description": "relative refs with absolute uris and defs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/schema-refs-absolute-uris-defs1.json", - "properties": { - "foo": { - "$id": "http://example.com/schema-refs-absolute-uris-defs2.json", - "$defs": { - "inner": { - "properties": { - "bar": { "type": "string" } - } - } - }, - "$ref": "#/$defs/inner" - } - }, - "$ref": "schema-refs-absolute-uris-defs2.json" - }, - "tests": [ - { - "description": "invalid on inner field", - "data": { - "foo": { - "bar": 1 - }, - "bar": "a" - }, - "valid": false - }, - { - "description": "invalid on outer field", - "data": { - "foo": { - "bar": "a" - }, - "bar": 1 - }, - "valid": false - }, - { - "description": "valid on both fields", - "data": { - "foo": { - "bar": "a" - }, - "bar": "a" - }, - "valid": true - } - ] - }, - { - "description": "$id must be resolved against nearest parent, not just immediate parent", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/a.json", - "$defs": { - "x": { - "$id": "http://example.com/b/c.json", - "not": { - "$defs": { - "y": { - "$id": "d.json", - "type": "number" - } - } - } - } - }, - "allOf": [ - { - "$ref": "http://example.com/b/d.json" - } - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "order of evaluation: $id and $ref", - "schema": { - "$comment": "$id must be evaluated before $ref to get the proper $ref destination", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/draft2020-12/ref-and-id1/base.json", - "$ref": "int.json", - "$defs": { - "bigint": { - "$comment": "canonical uri: https://example.com/ref-and-id1/int.json", - "$id": "int.json", - "maximum": 10 - }, - "smallint": { - "$comment": "canonical uri: https://example.com/ref-and-id1-int.json", - "$id": "/draft2020-12/ref-and-id1-int.json", - "maximum": 2 - } - } - }, - "tests": [ - { - "description": "data is valid against first definition", - "data": 5, - "valid": true - }, - { - "description": "data is invalid against first definition", - "data": 50, - "valid": false - } - ] - }, - { - "description": "order of evaluation: $id and $anchor and $ref", - "schema": { - "$comment": "$id must be evaluated before $ref to get the proper $ref destination", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/draft2020-12/ref-and-id2/base.json", - "$ref": "#bigint", - "$defs": { - "bigint": { - "$comment": "canonical uri: /ref-and-id2/base.json#/$defs/bigint; another valid uri for this location: /ref-and-id2/base.json#bigint", - "$anchor": "bigint", - "maximum": 10 - }, - "smallint": { - "$comment": "canonical uri: https://example.com/ref-and-id2#/$defs/smallint; another valid uri for this location: https://example.com/ref-and-id2/#bigint", - "$id": "https://example.com/draft2020-12/ref-and-id2/", - "$anchor": "bigint", - "maximum": 2 - } - } - }, - "tests": [ - { - "description": "data is valid against first definition", - "data": 5, - "valid": true - }, - { - "description": "data is invalid against first definition", - "data": 50, - "valid": false - } - ] - }, - { - "description": "simple URN base URI with $ref via the URN", - "schema": { - "$comment": "URIs do not have to have HTTP(s) schemes", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-ffff-ffff-4321feebdaed", - "minimum": 30, - "properties": { - "foo": {"$ref": "urn:uuid:deadbeef-1234-ffff-ffff-4321feebdaed"} - } - }, - "tests": [ - { - "description": "valid under the URN IDed schema", - "data": {"foo": 37}, - "valid": true - }, - { - "description": "invalid under the URN IDed schema", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "simple URN base URI with JSON pointer", - "schema": { - "$comment": "URIs do not have to have HTTP(s) schemes", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-00ff-ff00-4321feebdaed", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with NSS", - "schema": { - "$comment": "RFC 8141 §2.2", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:example:1/406/47452/2", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with r-component", - "schema": { - "$comment": "RFC 8141 §2.3.1", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:example:foo-bar-baz-qux?+CCResolve:cc=uk", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with q-component", - "schema": { - "$comment": "RFC 8141 §2.3.2", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:example:weather?=op=map&lat=39.56&lon=-104.85&datetime=1969-07-21T02:56:15Z", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with URN and JSON pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-0000-0000-4321feebdaed", - "properties": { - "foo": {"$ref": "urn:uuid:deadbeef-1234-0000-0000-4321feebdaed#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with URN and anchor ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-ff00-00ff-4321feebdaed", - "properties": { - "foo": {"$ref": "urn:uuid:deadbeef-1234-ff00-00ff-4321feebdaed#something"} - }, - "$defs": { - "bar": { - "$anchor": "something", - "type": "string" - } - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN ref with nested pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "urn:uuid:deadbeef-4321-ffff-ffff-1234feebdaed", - "$defs": { - "foo": { - "$id": "urn:uuid:deadbeef-4321-ffff-ffff-1234feebdaed", - "$defs": {"bar": {"type": "string"}}, - "$ref": "#/$defs/bar" - } - } - }, - "tests": [ - { - "description": "a string is valid", - "data": "bar", - "valid": true - }, - { - "description": "a non-string is invalid", - "data": 12, - "valid": false - } - ] - }, - { - "description": "ref to if", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://example.com/ref/if", - "if": { - "$id": "http://example.com/ref/if", - "type": "integer" - } - }, - "tests": [ - { - "description": "a non-integer is invalid due to the $ref", - "data": "foo", - "valid": false - }, - { - "description": "an integer is valid", - "data": 12, - "valid": true - } - ] - }, - { - "description": "ref to then", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://example.com/ref/then", - "then": { - "$id": "http://example.com/ref/then", - "type": "integer" - } - }, - "tests": [ - { - "description": "a non-integer is invalid due to the $ref", - "data": "foo", - "valid": false - }, - { - "description": "an integer is valid", - "data": 12, - "valid": true - } - ] - }, - { - "description": "ref to else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://example.com/ref/else", - "else": { - "$id": "http://example.com/ref/else", - "type": "integer" - } - }, - "tests": [ - { - "description": "a non-integer is invalid due to the $ref", - "data": "foo", - "valid": false - }, - { - "description": "an integer is valid", - "data": 12, - "valid": true - } - ] - }, - { - "description": "ref with absolute-path-reference", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/ref/absref.json", - "$defs": { - "a": { - "$id": "http://example.com/ref/absref/foobar.json", - "type": "number" - }, - "b": { - "$id": "http://example.com/absref/foobar.json", - "type": "string" - } - }, - "$ref": "/absref/foobar.json" - }, - "tests": [ - { - "description": "a string is valid", - "data": "foo", - "valid": true - }, - { - "description": "an integer is invalid", - "data": 12, - "valid": false - } - ] - }, - { - "description": "$id with file URI still resolves pointers - *nix", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "file:///folder/file.json", - "$defs": { - "foo": { - "type": "number" - } - }, - "$ref": "#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "$id with file URI still resolves pointers - windows", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "file:///c:/folder/file.json", - "$defs": { - "foo": { - "type": "number" - } - }, - "$ref": "#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "empty tokens in $ref json-pointer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "": { - "$defs": { - "": { "type": "number" } - } - } - }, - "allOf": [ - { - "$ref": "#/$defs//$defs/" - } - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/refRemote.json b/jsonschema/testdata/draft2020-12/refRemote.json deleted file mode 100644 index 047ac74c..00000000 --- a/jsonschema/testdata/draft2020-12/refRemote.json +++ /dev/null @@ -1,342 +0,0 @@ -[ - { - "description": "remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/integer.json" - }, - "tests": [ - { - "description": "remote ref valid", - "data": 1, - "valid": true - }, - { - "description": "remote ref invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "fragment within remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/subSchemas.json#/$defs/integer" - }, - "tests": [ - { - "description": "remote fragment valid", - "data": 1, - "valid": true - }, - { - "description": "remote fragment invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "anchor within remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/locationIndependentIdentifier.json#foo" - }, - "tests": [ - { - "description": "remote anchor valid", - "data": 1, - "valid": true - }, - { - "description": "remote anchor invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "ref within remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/subSchemas.json#/$defs/refToInteger" - }, - "tests": [ - { - "description": "ref within ref valid", - "data": 1, - "valid": true - }, - { - "description": "ref within ref invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "base URI change", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/", - "items": { - "$id": "baseUriChange/", - "items": {"$ref": "folderInteger.json"} - } - }, - "tests": [ - { - "description": "base URI change ref valid", - "data": [[1]], - "valid": true - }, - { - "description": "base URI change ref invalid", - "data": [["a"]], - "valid": false - } - ] - }, - { - "description": "base URI change - change folder", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/scope_change_defs1.json", - "type" : "object", - "properties": {"list": {"$ref": "baseUriChangeFolder/"}}, - "$defs": { - "baz": { - "$id": "baseUriChangeFolder/", - "type": "array", - "items": {"$ref": "folderInteger.json"} - } - } - }, - "tests": [ - { - "description": "number is valid", - "data": {"list": [1]}, - "valid": true - }, - { - "description": "string is invalid", - "data": {"list": ["a"]}, - "valid": false - } - ] - }, - { - "description": "base URI change - change folder in subschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/scope_change_defs2.json", - "type" : "object", - "properties": {"list": {"$ref": "baseUriChangeFolderInSubschema/#/$defs/bar"}}, - "$defs": { - "baz": { - "$id": "baseUriChangeFolderInSubschema/", - "$defs": { - "bar": { - "type": "array", - "items": {"$ref": "folderInteger.json"} - } - } - } - } - }, - "tests": [ - { - "description": "number is valid", - "data": {"list": [1]}, - "valid": true - }, - { - "description": "string is invalid", - "data": {"list": ["a"]}, - "valid": false - } - ] - }, - { - "description": "root ref in remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/object", - "type": "object", - "properties": { - "name": {"$ref": "name-defs.json#/$defs/orNull"} - } - }, - "tests": [ - { - "description": "string is valid", - "data": { - "name": "foo" - }, - "valid": true - }, - { - "description": "null is valid", - "data": { - "name": null - }, - "valid": true - }, - { - "description": "object is invalid", - "data": { - "name": { - "name": null - } - }, - "valid": false - } - ] - }, - { - "description": "remote ref with ref to defs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/schema-remote-ref-ref-defs1.json", - "$ref": "ref-and-defs.json" - }, - "tests": [ - { - "description": "invalid", - "data": { - "bar": 1 - }, - "valid": false - }, - { - "description": "valid", - "data": { - "bar": "a" - }, - "valid": true - } - ] - }, - { - "description": "Location-independent identifier in remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/locationIndependentIdentifier.json#/$defs/refToInteger" - }, - "tests": [ - { - "description": "integer is valid", - "data": 1, - "valid": true - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "retrieved nested refs resolve relative to their URI not $id", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/some-id", - "properties": { - "name": {"$ref": "nested/foo-ref-string.json"} - } - }, - "tests": [ - { - "description": "number is invalid", - "data": { - "name": {"foo": 1} - }, - "valid": false - }, - { - "description": "string is valid", - "data": { - "name": {"foo": "a"} - }, - "valid": true - } - ] - }, - { - "description": "remote HTTP ref with different $id", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/different-id-ref-string.json" - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "remote HTTP ref with different URN $id", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/urn-ref-string.json" - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "remote HTTP ref with nested absolute ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/nested-absolute-ref-to-string.json" - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "$ref to $ref finds detached $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/detached-ref.json#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/required.json b/jsonschema/testdata/draft2020-12/required.json deleted file mode 100644 index e66f29f2..00000000 --- a/jsonschema/testdata/draft2020-12/required.json +++ /dev/null @@ -1,158 +0,0 @@ -[ - { - "description": "required validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {}, - "bar": {} - }, - "required": ["foo"] - }, - "tests": [ - { - "description": "present required property is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "non-present required property is invalid", - "data": {"bar": 1}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores strings", - "data": "", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "required default validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {} - } - }, - "tests": [ - { - "description": "not required by default", - "data": {}, - "valid": true - } - ] - }, - { - "description": "required with empty array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {} - }, - "required": [] - }, - "tests": [ - { - "description": "property not required", - "data": {}, - "valid": true - } - ] - }, - { - "description": "required with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "required": [ - "foo\nbar", - "foo\"bar", - "foo\\bar", - "foo\rbar", - "foo\tbar", - "foo\fbar" - ] - }, - "tests": [ - { - "description": "object with all properties present is valid", - "data": { - "foo\nbar": 1, - "foo\"bar": 1, - "foo\\bar": 1, - "foo\rbar": 1, - "foo\tbar": 1, - "foo\fbar": 1 - }, - "valid": true - }, - { - "description": "object with some properties missing is invalid", - "data": { - "foo\nbar": "1", - "foo\"bar": "1" - }, - "valid": false - } - ] - }, - { - "description": "required properties whose names are Javascript object property names", - "comment": "Ensure JS implementations don't universally consider e.g. __proto__ to always be present in an object.", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "required": ["__proto__", "toString", "constructor"] - }, - "tests": [ - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - }, - { - "description": "none of the properties mentioned", - "data": {}, - "valid": false - }, - { - "description": "__proto__ present", - "data": { "__proto__": "foo" }, - "valid": false - }, - { - "description": "toString present", - "data": { "toString": { "length": 37 } }, - "valid": false - }, - { - "description": "constructor present", - "data": { "constructor": { "length": 37 } }, - "valid": false - }, - { - "description": "all present", - "data": { - "__proto__": 12, - "toString": { "length": "foo" }, - "constructor": 37 - }, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/type.json b/jsonschema/testdata/draft2020-12/type.json deleted file mode 100644 index 2123c408..00000000 --- a/jsonschema/testdata/draft2020-12/type.json +++ /dev/null @@ -1,501 +0,0 @@ -[ - { - "description": "integer type matches integers", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" - }, - "tests": [ - { - "description": "an integer is an integer", - "data": 1, - "valid": true - }, - { - "description": "a float with zero fractional part is an integer", - "data": 1.0, - "valid": true - }, - { - "description": "a float is not an integer", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not an integer", - "data": "foo", - "valid": false - }, - { - "description": "a string is still not an integer, even if it looks like one", - "data": "1", - "valid": false - }, - { - "description": "an object is not an integer", - "data": {}, - "valid": false - }, - { - "description": "an array is not an integer", - "data": [], - "valid": false - }, - { - "description": "a boolean is not an integer", - "data": true, - "valid": false - }, - { - "description": "null is not an integer", - "data": null, - "valid": false - } - ] - }, - { - "description": "number type matches numbers", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "number" - }, - "tests": [ - { - "description": "an integer is a number", - "data": 1, - "valid": true - }, - { - "description": "a float with zero fractional part is a number (and an integer)", - "data": 1.0, - "valid": true - }, - { - "description": "a float is a number", - "data": 1.1, - "valid": true - }, - { - "description": "a string is not a number", - "data": "foo", - "valid": false - }, - { - "description": "a string is still not a number, even if it looks like one", - "data": "1", - "valid": false - }, - { - "description": "an object is not a number", - "data": {}, - "valid": false - }, - { - "description": "an array is not a number", - "data": [], - "valid": false - }, - { - "description": "a boolean is not a number", - "data": true, - "valid": false - }, - { - "description": "null is not a number", - "data": null, - "valid": false - } - ] - }, - { - "description": "string type matches strings", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string" - }, - "tests": [ - { - "description": "1 is not a string", - "data": 1, - "valid": false - }, - { - "description": "a float is not a string", - "data": 1.1, - "valid": false - }, - { - "description": "a string is a string", - "data": "foo", - "valid": true - }, - { - "description": "a string is still a string, even if it looks like a number", - "data": "1", - "valid": true - }, - { - "description": "an empty string is still a string", - "data": "", - "valid": true - }, - { - "description": "an object is not a string", - "data": {}, - "valid": false - }, - { - "description": "an array is not a string", - "data": [], - "valid": false - }, - { - "description": "a boolean is not a string", - "data": true, - "valid": false - }, - { - "description": "null is not a string", - "data": null, - "valid": false - } - ] - }, - { - "description": "object type matches objects", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object" - }, - "tests": [ - { - "description": "an integer is not an object", - "data": 1, - "valid": false - }, - { - "description": "a float is not an object", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not an object", - "data": "foo", - "valid": false - }, - { - "description": "an object is an object", - "data": {}, - "valid": true - }, - { - "description": "an array is not an object", - "data": [], - "valid": false - }, - { - "description": "a boolean is not an object", - "data": true, - "valid": false - }, - { - "description": "null is not an object", - "data": null, - "valid": false - } - ] - }, - { - "description": "array type matches arrays", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "array" - }, - "tests": [ - { - "description": "an integer is not an array", - "data": 1, - "valid": false - }, - { - "description": "a float is not an array", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not an array", - "data": "foo", - "valid": false - }, - { - "description": "an object is not an array", - "data": {}, - "valid": false - }, - { - "description": "an array is an array", - "data": [], - "valid": true - }, - { - "description": "a boolean is not an array", - "data": true, - "valid": false - }, - { - "description": "null is not an array", - "data": null, - "valid": false - } - ] - }, - { - "description": "boolean type matches booleans", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "boolean" - }, - "tests": [ - { - "description": "an integer is not a boolean", - "data": 1, - "valid": false - }, - { - "description": "zero is not a boolean", - "data": 0, - "valid": false - }, - { - "description": "a float is not a boolean", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not a boolean", - "data": "foo", - "valid": false - }, - { - "description": "an empty string is not a boolean", - "data": "", - "valid": false - }, - { - "description": "an object is not a boolean", - "data": {}, - "valid": false - }, - { - "description": "an array is not a boolean", - "data": [], - "valid": false - }, - { - "description": "true is a boolean", - "data": true, - "valid": true - }, - { - "description": "false is a boolean", - "data": false, - "valid": true - }, - { - "description": "null is not a boolean", - "data": null, - "valid": false - } - ] - }, - { - "description": "null type matches only the null object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "null" - }, - "tests": [ - { - "description": "an integer is not null", - "data": 1, - "valid": false - }, - { - "description": "a float is not null", - "data": 1.1, - "valid": false - }, - { - "description": "zero is not null", - "data": 0, - "valid": false - }, - { - "description": "a string is not null", - "data": "foo", - "valid": false - }, - { - "description": "an empty string is not null", - "data": "", - "valid": false - }, - { - "description": "an object is not null", - "data": {}, - "valid": false - }, - { - "description": "an array is not null", - "data": [], - "valid": false - }, - { - "description": "true is not null", - "data": true, - "valid": false - }, - { - "description": "false is not null", - "data": false, - "valid": false - }, - { - "description": "null is null", - "data": null, - "valid": true - } - ] - }, - { - "description": "multiple types can be specified in an array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["integer", "string"] - }, - "tests": [ - { - "description": "an integer is valid", - "data": 1, - "valid": true - }, - { - "description": "a string is valid", - "data": "foo", - "valid": true - }, - { - "description": "a float is invalid", - "data": 1.1, - "valid": false - }, - { - "description": "an object is invalid", - "data": {}, - "valid": false - }, - { - "description": "an array is invalid", - "data": [], - "valid": false - }, - { - "description": "a boolean is invalid", - "data": true, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - } - ] - }, - { - "description": "type as array with one item", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["string"] - }, - "tests": [ - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "number is invalid", - "data": 123, - "valid": false - } - ] - }, - { - "description": "type: array or object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["array", "object"] - }, - "tests": [ - { - "description": "array is valid", - "data": [1,2,3], - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": 123}, - "valid": true - }, - { - "description": "number is invalid", - "data": 123, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - } - ] - }, - { - "description": "type: array, object or null", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["array", "object", "null"] - }, - "tests": [ - { - "description": "array is valid", - "data": [1,2,3], - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": 123}, - "valid": true - }, - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "number is invalid", - "data": 123, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/unevaluatedItems.json b/jsonschema/testdata/draft2020-12/unevaluatedItems.json deleted file mode 100644 index f861cefa..00000000 --- a/jsonschema/testdata/draft2020-12/unevaluatedItems.json +++ /dev/null @@ -1,798 +0,0 @@ -[ - { - "description": "unevaluatedItems true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": true - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo"], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems as schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": { "type": "string" } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with valid unevaluated items", - "data": ["foo"], - "valid": true - }, - { - "description": "with invalid unevaluated items", - "data": [42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with uniform items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": { "type": "string" }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "unevaluatedItems doesn't apply", - "data": ["foo", "bar"], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with tuple", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "type": "string" } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with items and prefixItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "type": "string" } - ], - "items": true, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "unevaluatedItems doesn't apply", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": {"type": "number"}, - "unevaluatedItems": {"type": "string"} - }, - "tests": [ - { - "description": "valid under items", - "comment": "no elements are considered by unevaluatedItems", - "data": [5, 6, 7, 8], - "valid": true - }, - { - "description": "invalid under items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with nested tuple", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "type": "string" } - ], - "allOf": [ - { - "prefixItems": [ - true, - { "type": "number" } - ] - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", 42], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", 42, true], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with nested items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": {"type": "boolean"}, - "anyOf": [ - { "items": {"type": "string"} }, - true - ] - }, - "tests": [ - { - "description": "with only (valid) additional items", - "data": [true, false], - "valid": true - }, - { - "description": "with no additional items", - "data": ["yes", "no"], - "valid": true - }, - { - "description": "with invalid additional item", - "data": ["yes", false], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with nested prefixItems and items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "prefixItems": [ - { "type": "string" } - ], - "items": true - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no additional items", - "data": ["foo"], - "valid": true - }, - { - "description": "with additional items", - "data": ["foo", 42, true], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with nested unevaluatedItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "prefixItems": [ - { "type": "string" } - ] - }, - { "unevaluatedItems": true } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no additional items", - "data": ["foo"], - "valid": true - }, - { - "description": "with additional items", - "data": ["foo", 42, true], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with anyOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "anyOf": [ - { - "prefixItems": [ - true, - { "const": "bar" } - ] - }, - { - "prefixItems": [ - true, - true, - { "const": "baz" } - ] - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "when one schema matches and has no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "when one schema matches and has unevaluated items", - "data": ["foo", "bar", 42], - "valid": false - }, - { - "description": "when two schemas match and has no unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": true - }, - { - "description": "when two schemas match and has unevaluated items", - "data": ["foo", "bar", "baz", 42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "oneOf": [ - { - "prefixItems": [ - true, - { "const": "bar" } - ] - }, - { - "prefixItems": [ - true, - { "const": "baz" } - ] - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", 42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with not", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "not": { - "not": { - "prefixItems": [ - true, - { "const": "bar" } - ] - } - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with unevaluated items", - "data": ["foo", "bar"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with if/then/else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "if": { - "prefixItems": [ - true, - { "const": "bar" } - ] - }, - "then": { - "prefixItems": [ - true, - true, - { "const": "then" } - ] - }, - "else": { - "prefixItems": [ - true, - true, - true, - { "const": "else" } - ] - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "when if matches and it has no unevaluated items", - "data": ["foo", "bar", "then"], - "valid": true - }, - { - "description": "when if matches and it has unevaluated items", - "data": ["foo", "bar", "then", "else"], - "valid": false - }, - { - "description": "when if doesn't match and it has no unevaluated items", - "data": ["foo", 42, 42, "else"], - "valid": true - }, - { - "description": "when if doesn't match and it has unevaluated items", - "data": ["foo", 42, 42, "else", 42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [true], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/bar", - "prefixItems": [ - { "type": "string" } - ], - "unevaluatedItems": false, - "$defs": { - "bar": { - "prefixItems": [ - true, - { "type": "string" } - ] - } - } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems before $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": false, - "prefixItems": [ - { "type": "string" } - ], - "$ref": "#/$defs/bar", - "$defs": { - "bar": { - "prefixItems": [ - true, - { "type": "string" } - ] - } - } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with $dynamicRef", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/unevaluated-items-with-dynamic-ref/derived", - - "$ref": "./baseSchema", - - "$defs": { - "derived": { - "$dynamicAnchor": "addons", - "prefixItems": [ - true, - { "type": "string" } - ] - }, - "baseSchema": { - "$id": "./baseSchema", - - "$comment": "unevaluatedItems comes first so it's more likely to catch bugs with implementations that are sensitive to keyword ordering", - "unevaluatedItems": false, - "type": "array", - "prefixItems": [ - { "type": "string" } - ], - "$dynamicRef": "#addons", - - "$defs": { - "defaultAddons": { - "$comment": "Needed to satisfy the bookending requirement", - "$dynamicAnchor": "addons" - } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems can't see inside cousins", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "prefixItems": [ true ] - }, - { "unevaluatedItems": false } - ] - }, - "tests": [ - { - "description": "always fails", - "data": [ 1 ], - "valid": false - } - ] - }, - { - "description": "item is evaluated in an uncle schema to unevaluatedItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": { - "prefixItems": [ - { "type": "string" } - ], - "unevaluatedItems": false - } - }, - "anyOf": [ - { - "properties": { - "foo": { - "prefixItems": [ - true, - { "type": "string" } - ] - } - } - } - ] - }, - "tests": [ - { - "description": "no extra items", - "data": { - "foo": [ - "test" - ] - }, - "valid": true - }, - { - "description": "uncle keyword evaluation is not significant", - "data": { - "foo": [ - "test", - "test" - ] - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedItems depends on adjacent contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [true], - "contains": {"type": "string"}, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "second item is evaluated by contains", - "data": [ 1, "foo" ], - "valid": true - }, - { - "description": "contains fails, second item is not evaluated", - "data": [ 1, 2 ], - "valid": false - }, - { - "description": "contains passes, second item is not evaluated", - "data": [ 1, 2, "foo" ], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems depends on multiple nested contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { "contains": { "multipleOf": 2 } }, - { "contains": { "multipleOf": 3 } } - ], - "unevaluatedItems": { "multipleOf": 5 } - }, - "tests": [ - { - "description": "5 not evaluated, passes unevaluatedItems", - "data": [ 2, 3, 4, 5, 6 ], - "valid": true - }, - { - "description": "7 not evaluated, fails unevaluatedItems", - "data": [ 2, 3, 4, 7, 8 ], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems and contains interact to control item dependency relationship", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "contains": {"const": "a"} - }, - "then": { - "if": { - "contains": {"const": "b"} - }, - "then": { - "if": { - "contains": {"const": "c"} - } - } - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "empty array is valid", - "data": [], - "valid": true - }, - { - "description": "only a's are valid", - "data": [ "a", "a" ], - "valid": true - }, - { - "description": "a's and b's are valid", - "data": [ "a", "b", "a", "b", "a" ], - "valid": true - }, - { - "description": "a's, b's and c's are valid", - "data": [ "c", "a", "c", "c", "b", "a" ], - "valid": true - }, - { - "description": "only b's are invalid", - "data": [ "b", "b" ], - "valid": false - }, - { - "description": "only c's are invalid", - "data": [ "c", "c" ], - "valid": false - }, - { - "description": "only b's and c's are invalid", - "data": [ "c", "b", "c", "b", "c" ], - "valid": false - }, - { - "description": "only a's and c's are invalid", - "data": [ "c", "a", "c", "a", "c" ], - "valid": false - } - ] - }, - { - "description": "non-array instances are valid", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": false - }, - "tests": [ - { - "description": "ignores booleans", - "data": true, - "valid": true - }, - { - "description": "ignores integers", - "data": 123, - "valid": true - }, - { - "description": "ignores floats", - "data": 1.0, - "valid": true - }, - { - "description": "ignores objects", - "data": {}, - "valid": true - }, - { - "description": "ignores strings", - "data": "foo", - "valid": true - }, - { - "description": "ignores null", - "data": null, - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null elements", - "data": [ null ], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems can see annotations from if without then and else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "prefixItems": [{"const": "a"}] - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "valid in case if is evaluated", - "data": [ "a" ], - "valid": true - }, - { - "description": "invalid in case if is evaluated", - "data": [ "b" ], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/unevaluatedProperties.json b/jsonschema/testdata/draft2020-12/unevaluatedProperties.json deleted file mode 100644 index ae29c9eb..00000000 --- a/jsonschema/testdata/draft2020-12/unevaluatedProperties.json +++ /dev/null @@ -1,1601 +0,0 @@ -[ - { - "description": "unevaluatedProperties true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": true - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": {}, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": { - "type": "string", - "minLength": 3 - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": {}, - "valid": true - }, - { - "description": "with valid unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with invalid unevaluated properties", - "data": { - "foo": "fo" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": {}, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with adjacent properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with adjacent patternProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "patternProperties": { - "^foo": { "type": "string" } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with adjacent additionalProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "additionalProperties": true, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with nested properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "properties": { - "bar": { "type": "string" } - } - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with nested patternProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "patternProperties": { - "^bar": { "type": "string" } - } - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with nested additionalProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "additionalProperties": true - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with nested unevaluatedProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "unevaluatedProperties": true - } - ], - "unevaluatedProperties": { - "type": "string", - "maxLength": 2 - } - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with anyOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "anyOf": [ - { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - }, - { - "properties": { - "baz": { "const": "baz" } - }, - "required": ["baz"] - }, - { - "properties": { - "quux": { "const": "quux" } - }, - "required": ["quux"] - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when one matches and has no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "when one matches and has unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "not-baz" - }, - "valid": false - }, - { - "description": "when two match and has no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": true - }, - { - "description": "when two match and has unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz", - "quux": "not-quux" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "oneOf": [ - { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - }, - { - "properties": { - "baz": { "const": "baz" } - }, - "required": ["baz"] - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "quux": "quux" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with not", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "not": { - "not": { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with if/then/else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "if": { - "properties": { - "foo": { "const": "then" } - }, - "required": ["foo"] - }, - "then": { - "properties": { - "bar": { "type": "string" } - }, - "required": ["bar"] - }, - "else": { - "properties": { - "baz": { "type": "string" } - }, - "required": ["baz"] - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when if is true and has no unevaluated properties", - "data": { - "foo": "then", - "bar": "bar" - }, - "valid": true - }, - { - "description": "when if is true and has unevaluated properties", - "data": { - "foo": "then", - "bar": "bar", - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has no unevaluated properties", - "data": { - "baz": "baz" - }, - "valid": true - }, - { - "description": "when if is false and has unevaluated properties", - "data": { - "foo": "else", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with if/then/else, then not defined", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "if": { - "properties": { - "foo": { "const": "then" } - }, - "required": ["foo"] - }, - "else": { - "properties": { - "baz": { "type": "string" } - }, - "required": ["baz"] - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when if is true and has no unevaluated properties", - "data": { - "foo": "then", - "bar": "bar" - }, - "valid": false - }, - { - "description": "when if is true and has unevaluated properties", - "data": { - "foo": "then", - "bar": "bar", - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has no unevaluated properties", - "data": { - "baz": "baz" - }, - "valid": true - }, - { - "description": "when if is false and has unevaluated properties", - "data": { - "foo": "else", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with if/then/else, else not defined", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "if": { - "properties": { - "foo": { "const": "then" } - }, - "required": ["foo"] - }, - "then": { - "properties": { - "bar": { "type": "string" } - }, - "required": ["bar"] - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when if is true and has no unevaluated properties", - "data": { - "foo": "then", - "bar": "bar" - }, - "valid": true - }, - { - "description": "when if is true and has unevaluated properties", - "data": { - "foo": "then", - "bar": "bar", - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has no unevaluated properties", - "data": { - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has unevaluated properties", - "data": { - "foo": "else", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with dependentSchemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "dependentSchemas": { - "foo": { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [true], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "$ref": "#/$defs/bar", - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false, - "$defs": { - "bar": { - "properties": { - "bar": { "type": "string" } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties before $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": false, - "properties": { - "foo": { "type": "string" } - }, - "$ref": "#/$defs/bar", - "$defs": { - "bar": { - "properties": { - "bar": { "type": "string" } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with $dynamicRef", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/unevaluated-properties-with-dynamic-ref/derived", - - "$ref": "./baseSchema", - - "$defs": { - "derived": { - "$dynamicAnchor": "addons", - "properties": { - "bar": { "type": "string" } - } - }, - "baseSchema": { - "$id": "./baseSchema", - - "$comment": "unevaluatedProperties comes first so it's more likely to catch bugs with implementations that are sensitive to keyword ordering", - "unevaluatedProperties": false, - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "$dynamicRef": "#addons", - - "$defs": { - "defaultAddons": { - "$comment": "Needed to satisfy the bookending requirement", - "$dynamicAnchor": "addons" - } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties can't see inside cousins", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "properties": { - "foo": true - } - }, - { - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "always fails", - "data": { - "foo": 1 - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties can't see inside cousins (reverse order)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "unevaluatedProperties": false - }, - { - "properties": { - "foo": true - } - } - ] - }, - "tests": [ - { - "description": "always fails", - "data": { - "foo": 1 - }, - "valid": false - } - ] - }, - { - "description": "nested unevaluatedProperties, outer false, inner true, properties outside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "unevaluatedProperties": true - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "nested unevaluatedProperties, outer false, inner true, properties inside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": true - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "nested unevaluatedProperties, outer true, inner false, properties outside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "unevaluatedProperties": false - } - ], - "unevaluatedProperties": true - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": false - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "nested unevaluatedProperties, outer true, inner false, properties inside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false - } - ], - "unevaluatedProperties": true - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "cousin unevaluatedProperties, true and false, true with properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": true - }, - { - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": false - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "cousin unevaluatedProperties, true and false, false with properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "unevaluatedProperties": true - }, - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "property is evaluated in an uncle schema to unevaluatedProperties", - "comment": "see https://stackoverflow.com/questions/66936884/deeply-nested-unevaluatedproperties-and-their-expectations", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { - "type": "object", - "properties": { - "bar": { - "type": "string" - } - }, - "unevaluatedProperties": false - } - }, - "anyOf": [ - { - "properties": { - "foo": { - "properties": { - "faz": { - "type": "string" - } - } - } - } - } - ] - }, - "tests": [ - { - "description": "no extra properties", - "data": { - "foo": { - "bar": "test" - } - }, - "valid": true - }, - { - "description": "uncle keyword evaluation is not significant", - "data": { - "foo": { - "bar": "test", - "faz": "test" - } - }, - "valid": false - } - ] - }, - { - "description": "in-place applicator siblings, allOf has unevaluated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": true - }, - "unevaluatedProperties": false - } - ], - "anyOf": [ - { - "properties": { - "bar": true - } - } - ] - }, - "tests": [ - { - "description": "base case: both properties present", - "data": { - "foo": 1, - "bar": 1 - }, - "valid": false - }, - { - "description": "in place applicator siblings, bar is missing", - "data": { - "foo": 1 - }, - "valid": true - }, - { - "description": "in place applicator siblings, foo is missing", - "data": { - "bar": 1 - }, - "valid": false - } - ] - }, - { - "description": "in-place applicator siblings, anyOf has unevaluated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": true - } - } - ], - "anyOf": [ - { - "properties": { - "bar": true - }, - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "base case: both properties present", - "data": { - "foo": 1, - "bar": 1 - }, - "valid": false - }, - { - "description": "in place applicator siblings, bar is missing", - "data": { - "foo": 1 - }, - "valid": false - }, - { - "description": "in place applicator siblings, foo is missing", - "data": { - "bar": 1 - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties + single cyclic ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "x": { "$ref": "#" } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "Empty is valid", - "data": {}, - "valid": true - }, - { - "description": "Single is valid", - "data": { "x": {} }, - "valid": true - }, - { - "description": "Unevaluated on 1st level is invalid", - "data": { "x": {}, "y": {} }, - "valid": false - }, - { - "description": "Nested is valid", - "data": { "x": { "x": {} } }, - "valid": true - }, - { - "description": "Unevaluated on 2nd level is invalid", - "data": { "x": { "x": {}, "y": {} } }, - "valid": false - }, - { - "description": "Deep nested is valid", - "data": { "x": { "x": { "x": {} } } }, - "valid": true - }, - { - "description": "Unevaluated on 3rd level is invalid", - "data": { "x": { "x": { "x": {}, "y": {} } } }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties + ref inside allOf / oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "one": { - "properties": { "a": true } - }, - "two": { - "required": ["x"], - "properties": { "x": true } - } - }, - "allOf": [ - { "$ref": "#/$defs/one" }, - { "properties": { "b": true } }, - { - "oneOf": [ - { "$ref": "#/$defs/two" }, - { - "required": ["y"], - "properties": { "y": true } - } - ] - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "Empty is invalid (no x or y)", - "data": {}, - "valid": false - }, - { - "description": "a and b are invalid (no x or y)", - "data": { "a": 1, "b": 1 }, - "valid": false - }, - { - "description": "x and y are invalid", - "data": { "x": 1, "y": 1 }, - "valid": false - }, - { - "description": "a and x are valid", - "data": { "a": 1, "x": 1 }, - "valid": true - }, - { - "description": "a and y are valid", - "data": { "a": 1, "y": 1 }, - "valid": true - }, - { - "description": "a and b and x are valid", - "data": { "a": 1, "b": 1, "x": 1 }, - "valid": true - }, - { - "description": "a and b and y are valid", - "data": { "a": 1, "b": 1, "y": 1 }, - "valid": true - }, - { - "description": "a and b and x and y are invalid", - "data": { "a": 1, "b": 1, "x": 1, "y": 1 }, - "valid": false - } - ] - }, - { - "description": "dynamic evalation inside nested refs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "one": { - "oneOf": [ - { "$ref": "#/$defs/two" }, - { "required": ["b"], "properties": { "b": true } }, - { "required": ["xx"], "patternProperties": { "x": true } }, - { "required": ["all"], "unevaluatedProperties": true } - ] - }, - "two": { - "oneOf": [ - { "required": ["c"], "properties": { "c": true } }, - { "required": ["d"], "properties": { "d": true } } - ] - } - }, - "oneOf": [ - { "$ref": "#/$defs/one" }, - { "required": ["a"], "properties": { "a": true } } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "Empty is invalid", - "data": {}, - "valid": false - }, - { - "description": "a is valid", - "data": { "a": 1 }, - "valid": true - }, - { - "description": "b is valid", - "data": { "b": 1 }, - "valid": true - }, - { - "description": "c is valid", - "data": { "c": 1 }, - "valid": true - }, - { - "description": "d is valid", - "data": { "d": 1 }, - "valid": true - }, - { - "description": "a + b is invalid", - "data": { "a": 1, "b": 1 }, - "valid": false - }, - { - "description": "a + c is invalid", - "data": { "a": 1, "c": 1 }, - "valid": false - }, - { - "description": "a + d is invalid", - "data": { "a": 1, "d": 1 }, - "valid": false - }, - { - "description": "b + c is invalid", - "data": { "b": 1, "c": 1 }, - "valid": false - }, - { - "description": "b + d is invalid", - "data": { "b": 1, "d": 1 }, - "valid": false - }, - { - "description": "c + d is invalid", - "data": { "c": 1, "d": 1 }, - "valid": false - }, - { - "description": "xx is valid", - "data": { "xx": 1 }, - "valid": true - }, - { - "description": "xx + foox is valid", - "data": { "xx": 1, "foox": 1 }, - "valid": true - }, - { - "description": "xx + foo is invalid", - "data": { "xx": 1, "foo": 1 }, - "valid": false - }, - { - "description": "xx + a is invalid", - "data": { "xx": 1, "a": 1 }, - "valid": false - }, - { - "description": "xx + b is invalid", - "data": { "xx": 1, "b": 1 }, - "valid": false - }, - { - "description": "xx + c is invalid", - "data": { "xx": 1, "c": 1 }, - "valid": false - }, - { - "description": "xx + d is invalid", - "data": { "xx": 1, "d": 1 }, - "valid": false - }, - { - "description": "all is valid", - "data": { "all": 1 }, - "valid": true - }, - { - "description": "all + foo is valid", - "data": { "all": 1, "foo": 1 }, - "valid": true - }, - { - "description": "all + a is invalid", - "data": { "all": 1, "a": 1 }, - "valid": false - } - ] - }, - { - "description": "non-object instances are valid", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "ignores booleans", - "data": true, - "valid": true - }, - { - "description": "ignores integers", - "data": 123, - "valid": true - }, - { - "description": "ignores floats", - "data": 1.0, - "valid": true - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores strings", - "data": "foo", - "valid": true - }, - { - "description": "ignores null", - "data": null, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with null valued instance properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedProperties": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null valued properties", - "data": {"foo": null}, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties not affected by propertyNames", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"maxLength": 1}, - "unevaluatedProperties": { - "type": "number" - } - }, - "tests": [ - { - "description": "allows only number properties", - "data": {"a": 1}, - "valid": true - }, - { - "description": "string property is invalid", - "data": {"a": "b"}, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties can see annotations from if without then and else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "patternProperties": { - "foo": { - "type": "string" - } - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "valid in case if is evaluated", - "data": { - "foo": "a" - }, - "valid": true - }, - { - "description": "invalid in case if is evaluated", - "data": { - "bar": "a" - }, - "valid": false - } - ] - }, - { - "description": "dependentSchemas with unevaluatedProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo2": {}}, - "dependentSchemas": { - "foo" : {}, - "foo2": { - "properties": { - "bar":{} - } - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "unevaluatedProperties doesn't consider dependentSchemas", - "data": {"foo": ""}, - "valid": false - }, - { - "description": "unevaluatedProperties doesn't see bar when foo2 is absent", - "data": {"bar": ""}, - "valid": false - }, - { - "description": "unevaluatedProperties sees bar when foo2 is present", - "data": { "foo2": "", "bar": ""}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/uniqueItems.json b/jsonschema/testdata/draft2020-12/uniqueItems.json deleted file mode 100644 index 4ea3bf98..00000000 --- a/jsonschema/testdata/draft2020-12/uniqueItems.json +++ /dev/null @@ -1,419 +0,0 @@ -[ - { - "description": "uniqueItems validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "uniqueItems": true - }, - "tests": [ - { - "description": "unique array of integers is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "non-unique array of integers is invalid", - "data": [1, 1], - "valid": false - }, - { - "description": "non-unique array of more than two integers is invalid", - "data": [1, 2, 1], - "valid": false - }, - { - "description": "numbers are unique if mathematically unequal", - "data": [1.0, 1.00, 1], - "valid": false - }, - { - "description": "false is not equal to zero", - "data": [0, false], - "valid": true - }, - { - "description": "true is not equal to one", - "data": [1, true], - "valid": true - }, - { - "description": "unique array of strings is valid", - "data": ["foo", "bar", "baz"], - "valid": true - }, - { - "description": "non-unique array of strings is invalid", - "data": ["foo", "bar", "foo"], - "valid": false - }, - { - "description": "unique array of objects is valid", - "data": [{"foo": "bar"}, {"foo": "baz"}], - "valid": true - }, - { - "description": "non-unique array of objects is invalid", - "data": [{"foo": "bar"}, {"foo": "bar"}], - "valid": false - }, - { - "description": "property order of array of objects is ignored", - "data": [{"foo": "bar", "bar": "foo"}, {"bar": "foo", "foo": "bar"}], - "valid": false - }, - { - "description": "unique array of nested objects is valid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : false}}} - ], - "valid": true - }, - { - "description": "non-unique array of nested objects is invalid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : true}}} - ], - "valid": false - }, - { - "description": "unique array of arrays is valid", - "data": [["foo"], ["bar"]], - "valid": true - }, - { - "description": "non-unique array of arrays is invalid", - "data": [["foo"], ["foo"]], - "valid": false - }, - { - "description": "non-unique array of more than two arrays is invalid", - "data": [["foo"], ["bar"], ["foo"]], - "valid": false - }, - { - "description": "1 and true are unique", - "data": [1, true], - "valid": true - }, - { - "description": "0 and false are unique", - "data": [0, false], - "valid": true - }, - { - "description": "[1] and [true] are unique", - "data": [[1], [true]], - "valid": true - }, - { - "description": "[0] and [false] are unique", - "data": [[0], [false]], - "valid": true - }, - { - "description": "nested [1] and [true] are unique", - "data": [[[1], "foo"], [[true], "foo"]], - "valid": true - }, - { - "description": "nested [0] and [false] are unique", - "data": [[[0], "foo"], [[false], "foo"]], - "valid": true - }, - { - "description": "unique heterogeneous types are valid", - "data": [{}, [1], true, null, 1, "{}"], - "valid": true - }, - { - "description": "non-unique heterogeneous types are invalid", - "data": [{}, [1], true, null, {}, 1], - "valid": false - }, - { - "description": "different objects are unique", - "data": [{"a": 1, "b": 2}, {"a": 2, "b": 1}], - "valid": true - }, - { - "description": "objects are non-unique despite key order", - "data": [{"a": 1, "b": 2}, {"b": 2, "a": 1}], - "valid": false - }, - { - "description": "{\"a\": false} and {\"a\": 0} are unique", - "data": [{"a": false}, {"a": 0}], - "valid": true - }, - { - "description": "{\"a\": true} and {\"a\": 1} are unique", - "data": [{"a": true}, {"a": 1}], - "valid": true - } - ] - }, - { - "description": "uniqueItems with an array of items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": true - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is not valid", - "data": [false, false], - "valid": false - }, - { - "description": "[true, true] from items array is not valid", - "data": [true, true], - "valid": false - }, - { - "description": "unique array extended from [false, true] is valid", - "data": [false, true, "foo", "bar"], - "valid": true - }, - { - "description": "unique array extended from [true, false] is valid", - "data": [true, false, "foo", "bar"], - "valid": true - }, - { - "description": "non-unique array extended from [false, true] is not valid", - "data": [false, true, "foo", "foo"], - "valid": false - }, - { - "description": "non-unique array extended from [true, false] is not valid", - "data": [true, false, "foo", "foo"], - "valid": false - } - ] - }, - { - "description": "uniqueItems with an array of items and additionalItems=false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": true, - "items": false - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is not valid", - "data": [false, false], - "valid": false - }, - { - "description": "[true, true] from items array is not valid", - "data": [true, true], - "valid": false - }, - { - "description": "extra items are invalid even if unique", - "data": [false, true, null], - "valid": false - } - ] - }, - { - "description": "uniqueItems=false validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "uniqueItems": false - }, - "tests": [ - { - "description": "unique array of integers is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "non-unique array of integers is valid", - "data": [1, 1], - "valid": true - }, - { - "description": "numbers are unique if mathematically unequal", - "data": [1.0, 1.00, 1], - "valid": true - }, - { - "description": "false is not equal to zero", - "data": [0, false], - "valid": true - }, - { - "description": "true is not equal to one", - "data": [1, true], - "valid": true - }, - { - "description": "unique array of objects is valid", - "data": [{"foo": "bar"}, {"foo": "baz"}], - "valid": true - }, - { - "description": "non-unique array of objects is valid", - "data": [{"foo": "bar"}, {"foo": "bar"}], - "valid": true - }, - { - "description": "unique array of nested objects is valid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : false}}} - ], - "valid": true - }, - { - "description": "non-unique array of nested objects is valid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : true}}} - ], - "valid": true - }, - { - "description": "unique array of arrays is valid", - "data": [["foo"], ["bar"]], - "valid": true - }, - { - "description": "non-unique array of arrays is valid", - "data": [["foo"], ["foo"]], - "valid": true - }, - { - "description": "1 and true are unique", - "data": [1, true], - "valid": true - }, - { - "description": "0 and false are unique", - "data": [0, false], - "valid": true - }, - { - "description": "unique heterogeneous types are valid", - "data": [{}, [1], true, null, 1], - "valid": true - }, - { - "description": "non-unique heterogeneous types are valid", - "data": [{}, [1], true, null, {}, 1], - "valid": true - } - ] - }, - { - "description": "uniqueItems=false with an array of items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": false - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is valid", - "data": [false, false], - "valid": true - }, - { - "description": "[true, true] from items array is valid", - "data": [true, true], - "valid": true - }, - { - "description": "unique array extended from [false, true] is valid", - "data": [false, true, "foo", "bar"], - "valid": true - }, - { - "description": "unique array extended from [true, false] is valid", - "data": [true, false, "foo", "bar"], - "valid": true - }, - { - "description": "non-unique array extended from [false, true] is valid", - "data": [false, true, "foo", "foo"], - "valid": true - }, - { - "description": "non-unique array extended from [true, false] is valid", - "data": [true, false, "foo", "foo"], - "valid": true - } - ] - }, - { - "description": "uniqueItems=false with an array of items and additionalItems=false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": false, - "items": false - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is valid", - "data": [false, false], - "valid": true - }, - { - "description": "[true, true] from items array is valid", - "data": [true, true], - "valid": true - }, - { - "description": "extra items are invalid even if unique", - "data": [false, true, null], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/remotes/README.md b/jsonschema/testdata/remotes/README.md deleted file mode 100644 index 8a641dbd..00000000 --- a/jsonschema/testdata/remotes/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# JSON Schema test suite: remote references - -These files were copied from -https://github.com/json-schema-org/JSON-Schema-Test-Suite/tree/83e866b46c9f9e7082fd51e83a61c5f2145a1ab7/remotes. diff --git a/jsonschema/testdata/remotes/different-id-ref-string.json b/jsonschema/testdata/remotes/different-id-ref-string.json deleted file mode 100644 index 7f888609..00000000 --- a/jsonschema/testdata/remotes/different-id-ref-string.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "$id": "http://localhost:1234/real-id-ref-string.json", - "$defs": {"bar": {"type": "string"}}, - "$ref": "#/$defs/bar" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json b/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json b/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json b/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json b/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json deleted file mode 100644 index 07cce1da..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/detached-dynamicref.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "foo": { - "$dynamicRef": "#detached" - }, - "detached": { - "$dynamicAnchor": "detached", - "type": "integer" - } - } -} \ No newline at end of file diff --git a/jsonschema/testdata/remotes/draft2020-12/detached-ref.json b/jsonschema/testdata/remotes/draft2020-12/detached-ref.json deleted file mode 100644 index 9c2dca93..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/detached-ref.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/detached-ref.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "foo": { - "$ref": "#detached" - }, - "detached": { - "$anchor": "detached", - "type": "integer" - } - } -} \ No newline at end of file diff --git a/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json b/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json deleted file mode 100644 index 65bc0c21..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "description": "extendible array", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/extendible-dynamic-ref.json", - "type": "object", - "properties": { - "elements": { - "type": "array", - "items": { - "$dynamicRef": "#elements" - } - } - }, - "required": ["elements"], - "additionalProperties": false, - "$defs": { - "elements": { - "$dynamicAnchor": "elements" - } - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json b/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json deleted file mode 100644 index 43a711c9..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/format-assertion-false.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/core": true, - "https://json-schema.org/draft/2020-12/vocab/format-assertion": false - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/format-assertion" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json b/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json deleted file mode 100644 index 39c6b0ab..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/format-assertion-true.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/core": true, - "https://json-schema.org/draft/2020-12/vocab/format-assertion": true - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/format-assertion" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/integer.json b/jsonschema/testdata/remotes/draft2020-12/integer.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/integer.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json b/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json deleted file mode 100644 index 6565a1ee..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "refToInteger": { - "$ref": "#foo" - }, - "A": { - "$anchor": "foo", - "type": "integer" - } - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json b/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json deleted file mode 100644 index 71be8b5d..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/metaschema-no-validation.json", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/applicator": true, - "https://json-schema.org/draft/2020-12/vocab/core": true - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/applicator" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json b/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json deleted file mode 100644 index a6963e54..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/metaschema-optional-vocabulary.json", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/validation": true, - "https://json-schema.org/draft/2020-12/vocab/core": true, - "http://localhost:1234/draft/2020-12/vocab/custom": false - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/validation" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/name-defs.json b/jsonschema/testdata/remotes/draft2020-12/name-defs.json deleted file mode 100644 index 67bc33c5..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/name-defs.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "orNull": { - "anyOf": [ - { - "type": "null" - }, - { - "$ref": "#" - } - ] - } - }, - "type": "string" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json b/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json deleted file mode 100644 index 29661ff9..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": {"$ref": "string.json"} - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/nested/string.json b/jsonschema/testdata/remotes/draft2020-12/nested/string.json deleted file mode 100644 index 6607ac53..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/nested/string.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/prefixItems.json b/jsonschema/testdata/remotes/draft2020-12/prefixItems.json deleted file mode 100644 index acd8293c..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/prefixItems.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/prefixItems.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - {"type": "string"} - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json b/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json deleted file mode 100644 index 16d30fa3..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/ref-and-defs.json", - "$defs": { - "inner": { - "properties": { - "bar": { "type": "string" } - } - } - }, - "$ref": "#/$defs/inner" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/subSchemas.json b/jsonschema/testdata/remotes/draft2020-12/subSchemas.json deleted file mode 100644 index 1bb4846d..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/subSchemas.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "integer": { - "type": "integer" - }, - "refToInteger": { - "$ref": "#/$defs/integer" - } - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/tree.json b/jsonschema/testdata/remotes/draft2020-12/tree.json deleted file mode 100644 index b07555fb..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/tree.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "description": "tree schema, extensible", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/tree.json", - "$dynamicAnchor": "node", - - "type": "object", - "properties": { - "data": true, - "children": { - "type": "array", - "items": { - "$dynamicRef": "#node" - } - } - } -} diff --git a/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json b/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json deleted file mode 100644 index f46c7616..00000000 --- a/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "$defs": { - "bar": { - "$id": "http://localhost:1234/the-nested-id.json", - "type": "string" - } - }, - "$ref": "http://localhost:1234/the-nested-id.json" -} diff --git a/jsonschema/testdata/remotes/urn-ref-string.json b/jsonschema/testdata/remotes/urn-ref-string.json deleted file mode 100644 index aca2211b..00000000 --- a/jsonschema/testdata/remotes/urn-ref-string.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "$id": "urn:uuid:feebdaed-ffff-0000-ffff-0000deadbeef", - "$defs": {"bar": {"type": "string"}}, - "$ref": "#/$defs/bar" -} diff --git a/jsonschema/util.go b/jsonschema/util.go deleted file mode 100644 index 25b916cd..00000000 --- a/jsonschema/util.go +++ /dev/null @@ -1,463 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "bytes" - "cmp" - "encoding/binary" - "encoding/json" - "fmt" - "hash/maphash" - "math" - "math/big" - "reflect" - "slices" - "strings" - "sync" -) - -// Equal reports whether two Go values representing JSON values are equal according -// to the JSON Schema spec. -// The values must not contain cycles. -// See https://json-schema.org/draft/2020-12/json-schema-core#section-4.2.2. -// It behaves like reflect.DeepEqual, except that numbers are compared according -// to mathematical equality. -func Equal(x, y any) bool { - return equalValue(reflect.ValueOf(x), reflect.ValueOf(y)) -} - -func equalValue(x, y reflect.Value) bool { - // Copied from src/reflect/deepequal.go, omitting the visited check (because JSON - // values are trees). - if !x.IsValid() || !y.IsValid() { - return x.IsValid() == y.IsValid() - } - - // Treat numbers specially. - rx, ok1 := jsonNumber(x) - ry, ok2 := jsonNumber(y) - if ok1 && ok2 { - return rx.Cmp(ry) == 0 - } - if x.Kind() != y.Kind() { - return false - } - switch x.Kind() { - case reflect.Array: - if x.Len() != y.Len() { - return false - } - for i := range x.Len() { - if !equalValue(x.Index(i), y.Index(i)) { - return false - } - } - return true - case reflect.Slice: - if x.IsNil() != y.IsNil() { - return false - } - if x.Len() != y.Len() { - return false - } - if x.UnsafePointer() == y.UnsafePointer() { - return true - } - // Special case for []byte, which is common. - if x.Type().Elem().Kind() == reflect.Uint8 && x.Type() == y.Type() { - return bytes.Equal(x.Bytes(), y.Bytes()) - } - for i := range x.Len() { - if !equalValue(x.Index(i), y.Index(i)) { - return false - } - } - return true - case reflect.Interface: - if x.IsNil() || y.IsNil() { - return x.IsNil() == y.IsNil() - } - return equalValue(x.Elem(), y.Elem()) - case reflect.Pointer: - if x.UnsafePointer() == y.UnsafePointer() { - return true - } - return equalValue(x.Elem(), y.Elem()) - case reflect.Struct: - t := x.Type() - if t != y.Type() { - return false - } - for i := range t.NumField() { - sf := t.Field(i) - if !sf.IsExported() { - continue - } - if !equalValue(x.FieldByIndex(sf.Index), y.FieldByIndex(sf.Index)) { - return false - } - } - return true - case reflect.Map: - if x.IsNil() != y.IsNil() { - return false - } - if x.Len() != y.Len() { - return false - } - if x.UnsafePointer() == y.UnsafePointer() { - return true - } - iter := x.MapRange() - for iter.Next() { - vx := iter.Value() - vy := y.MapIndex(iter.Key()) - if !vy.IsValid() || !equalValue(vx, vy) { - return false - } - } - return true - case reflect.Func: - if x.Type() != y.Type() { - return false - } - if x.IsNil() && y.IsNil() { - return true - } - panic("cannot compare functions") - case reflect.String: - return x.String() == y.String() - case reflect.Bool: - return x.Bool() == y.Bool() - // Ints, uints and floats handled in jsonNumber, at top of function. - default: - panic(fmt.Sprintf("unsupported kind: %s", x.Kind())) - } -} - -// hashValue adds v to the data hashed by h. v must not have cycles. -// hashValue panics if the value contains functions or channels, or maps whose -// key type is not string. -// It ignores unexported fields of structs. -// Calls to hashValue with the equal values (in the sense -// of [Equal]) result in the same sequence of values written to the hash. -func hashValue(h *maphash.Hash, v reflect.Value) { - // TODO: replace writes of basic types with WriteComparable in 1.24. - - writeUint := func(u uint64) { - var buf [8]byte - binary.BigEndian.PutUint64(buf[:], u) - h.Write(buf[:]) - } - - var write func(reflect.Value) - write = func(v reflect.Value) { - if r, ok := jsonNumber(v); ok { - // We want 1.0 and 1 to hash the same. - // big.Rats are always normalized, so they will be. - // We could do this more efficiently by handling the int and float cases - // separately, but that's premature. - writeUint(uint64(r.Sign() + 1)) - h.Write(r.Num().Bytes()) - h.Write(r.Denom().Bytes()) - return - } - switch v.Kind() { - case reflect.Invalid: - h.WriteByte(0) - case reflect.String: - h.WriteString(v.String()) - case reflect.Bool: - if v.Bool() { - h.WriteByte(1) - } else { - h.WriteByte(0) - } - case reflect.Complex64, reflect.Complex128: - c := v.Complex() - writeUint(math.Float64bits(real(c))) - writeUint(math.Float64bits(imag(c))) - case reflect.Array, reflect.Slice: - // Although we could treat []byte more efficiently, - // JSON values are unlikely to contain them. - writeUint(uint64(v.Len())) - for i := range v.Len() { - write(v.Index(i)) - } - case reflect.Interface, reflect.Pointer: - write(v.Elem()) - case reflect.Struct: - t := v.Type() - for i := range t.NumField() { - if sf := t.Field(i); sf.IsExported() { - write(v.FieldByIndex(sf.Index)) - } - } - case reflect.Map: - if v.Type().Key().Kind() != reflect.String { - panic("map with non-string key") - } - // Sort the keys so the hash is deterministic. - keys := v.MapKeys() - // Write the length. That distinguishes between, say, two consecutive - // maps with disjoint keys from one map that has the items of both. - writeUint(uint64(len(keys))) - slices.SortFunc(keys, func(x, y reflect.Value) int { return cmp.Compare(x.String(), y.String()) }) - for _, k := range keys { - write(k) - write(v.MapIndex(k)) - } - // Ints, uints and floats handled in jsonNumber, at top of function. - default: - panic(fmt.Sprintf("unsupported kind: %s", v.Kind())) - } - } - - write(v) -} - -// jsonNumber converts a numeric value or a json.Number to a [big.Rat]. -// If v is not a number, it returns nil, false. -func jsonNumber(v reflect.Value) (*big.Rat, bool) { - r := new(big.Rat) - switch { - case !v.IsValid(): - return nil, false - case v.CanInt(): - r.SetInt64(v.Int()) - case v.CanUint(): - r.SetUint64(v.Uint()) - case v.CanFloat(): - r.SetFloat64(v.Float()) - default: - jn, ok := v.Interface().(json.Number) - if !ok { - return nil, false - } - if _, ok := r.SetString(jn.String()); !ok { - // This can fail in rare cases; for example, "1e9999999". - // That is a valid JSON number, since the spec puts no limit on the size - // of the exponent. - return nil, false - } - } - return r, true -} - -// jsonType returns a string describing the type of the JSON value, -// as described in the JSON Schema specification: -// https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1. -// It returns "", false if the value is not valid JSON. -func jsonType(v reflect.Value) (string, bool) { - if !v.IsValid() { - // Not v.IsNil(): a nil []any is still a JSON array. - return "null", true - } - if v.CanInt() || v.CanUint() { - return "integer", true - } - if v.CanFloat() { - if _, f := math.Modf(v.Float()); f == 0 { - return "integer", true - } - return "number", true - } - switch v.Kind() { - case reflect.Bool: - return "boolean", true - case reflect.String: - return "string", true - case reflect.Slice, reflect.Array: - return "array", true - case reflect.Map, reflect.Struct: - return "object", true - default: - return "", false - } -} - -func assert(cond bool, msg string) { - if !cond { - panic("assertion failed: " + msg) - } -} - -// marshalStructWithMap marshals its first argument to JSON, treating the field named -// mapField as an embedded map. The first argument must be a pointer to -// a struct. The underlying type of mapField must be a map[string]any, and it must have -// a "-" json tag, meaning it will not be marshaled. -// -// For example, given this struct: -// -// type S struct { -// A int -// Extra map[string] any `json:"-"` -// } -// -// and this value: -// -// s := S{A: 1, Extra: map[string]any{"B": 2}} -// -// the call marshalJSONWithMap(s, "Extra") would return -// -// {"A": 1, "B": 2} -// -// It is an error if the map contains the same key as another struct field's -// JSON name. -// -// marshalStructWithMap calls json.Marshal on a value of type T, so T must not -// have a MarshalJSON method that calls this function, on pain of infinite regress. -// -// Note that there is a similar function in mcp/util.go, but they are not the same. -// Here the function requires `-` json tag, does not clear the mapField map, -// and handles embedded struct due to the implementation of jsonNames in this package. -// -// TODO: avoid this restriction on T by forcing it to marshal in a default way. -// See https://go.dev/play/p/EgXKJHxEx_R. -func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { - // Marshal the struct and the map separately, and concatenate the bytes. - // This strategy is dramatically less complicated than - // constructing a synthetic struct or map with the combined keys. - if s == nil { - return []byte("null"), nil - } - s2 := *s - vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) - mapVal := vMapField.Interface().(map[string]any) - - // Check for duplicates. - names := jsonNames(reflect.TypeFor[T]()) - for key := range mapVal { - if names[key] { - return nil, fmt.Errorf("map key %q duplicates struct field", key) - } - } - - structBytes, err := json.Marshal(s2) - if err != nil { - return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) - } - if len(mapVal) == 0 { - return structBytes, nil - } - mapBytes, err := json.Marshal(mapVal) - if err != nil { - return nil, err - } - if len(structBytes) == 2 { // must be "{}" - return mapBytes, nil - } - // "{X}" + "{Y}" => "{X,Y}" - res := append(structBytes[:len(structBytes)-1], ',') - res = append(res, mapBytes[1:]...) - return res, nil -} - -// unmarshalStructWithMap is the inverse of marshalStructWithMap. -// T has the same restrictions as in that function. -// -// Note that there is a similar function in mcp/util.go, but they are not the same. -// Here jsonNames also returns fields from embedded structs, hence this function -// handles embedded structs as well. -func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { - // Unmarshal into the struct, ignoring unknown fields. - if err := json.Unmarshal(data, v); err != nil { - return err - } - // Unmarshal into the map. - m := map[string]any{} - if err := json.Unmarshal(data, &m); err != nil { - return err - } - // Delete from the map the fields of the struct. - for n := range jsonNames(reflect.TypeFor[T]()) { - delete(m, n) - } - if len(m) != 0 { - reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) - } - return nil -} - -var jsonNamesMap sync.Map // from reflect.Type to map[string]bool - -// jsonNames returns the set of JSON object keys that t will marshal into, -// including fields from embedded structs in t. -// t must be a struct type. -// -// Note that there is a similar function in mcp/util.go, but they are not the same -// Here the function recurses over embedded structs and includes fields from them. -func jsonNames(t reflect.Type) map[string]bool { - // Lock not necessary: at worst we'll duplicate work. - if val, ok := jsonNamesMap.Load(t); ok { - return val.(map[string]bool) - } - m := map[string]bool{} - for i := range t.NumField() { - field := t.Field(i) - // handle embedded structs - if field.Anonymous { - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - for n := range jsonNames(fieldType) { - m[n] = true - } - continue - } - info := fieldJSONInfo(field) - if !info.omit { - m[info.name] = true - } - } - jsonNamesMap.Store(t, m) - return m -} - -type jsonInfo struct { - omit bool // unexported or first tag element is "-" - name string // Go field name or first tag element. Empty if omit is true. - settings map[string]bool // "omitempty", "omitzero", etc. -} - -// fieldJSONInfo reports information about how encoding/json -// handles the given struct field. -// If the field is unexported, jsonInfo.omit is true and no other jsonInfo field -// is populated. -// If the field is exported and has no tag, then name is the field's name and all -// other fields are false. -// Otherwise, the information is obtained from the tag. -func fieldJSONInfo(f reflect.StructField) jsonInfo { - if !f.IsExported() { - return jsonInfo{omit: true} - } - info := jsonInfo{name: f.Name} - if tag, ok := f.Tag.Lookup("json"); ok { - name, rest, found := strings.Cut(tag, ",") - // "-" means omit, but "-," means the name is "-" - if name == "-" && !found { - return jsonInfo{omit: true} - } - if name != "" { - info.name = name - } - if len(rest) > 0 { - info.settings = map[string]bool{} - for _, s := range strings.Split(rest, ",") { - info.settings[s] = true - } - } - } - return info -} - -// wrapf wraps *errp with the given formatted message if *errp is not nil. -func wrapf(errp *error, format string, args ...any) { - if *errp != nil { - *errp = fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), *errp) - } -} diff --git a/jsonschema/util_test.go b/jsonschema/util_test.go deleted file mode 100644 index 7934bff7..00000000 --- a/jsonschema/util_test.go +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "hash/maphash" - "reflect" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func TestEqual(t *testing.T) { - for _, tt := range []struct { - x1, x2 any - want bool - }{ - {0, 1, false}, - {1, 1.0, true}, - {nil, 0, false}, - {"0", 0, false}, - {2.5, 2.5, true}, - {[]int{1, 2}, []float64{1.0, 2.0}, true}, - {[]int(nil), []int{}, false}, - {[]map[string]any(nil), []map[string]any{}, false}, - { - map[string]any{"a": 1, "b": 2.0}, - map[string]any{"a": 1.0, "b": 2}, - true, - }, - } { - check := func(x1, x2 any, want bool) { - t.Helper() - if got := Equal(x1, x2); got != want { - t.Errorf("jsonEqual(%#v, %#v) = %t, want %t", x1, x2, got, want) - } - } - check(tt.x1, tt.x1, true) - check(tt.x2, tt.x2, true) - check(tt.x1, tt.x2, tt.want) - check(tt.x2, tt.x1, tt.want) - } -} - -func TestJSONType(t *testing.T) { - for _, tt := range []struct { - val string - want string - }{ - {`null`, "null"}, - {`0`, "integer"}, - {`0.0`, "integer"}, - {`1e2`, "integer"}, - {`0.1`, "number"}, - {`""`, "string"}, - {`true`, "boolean"}, - {`[]`, "array"}, - {`{}`, "object"}, - } { - var val any - if err := json.Unmarshal([]byte(tt.val), &val); err != nil { - t.Fatal(err) - } - got, ok := jsonType(reflect.ValueOf(val)) - if !ok { - t.Fatalf("jsonType failed on %q", tt.val) - } - if got != tt.want { - t.Errorf("%s: got %q, want %q", tt.val, got, tt.want) - } - - } -} - -func TestHash(t *testing.T) { - x := map[string]any{ - "s": []any{1, "foo", nil, true}, - "f": 2.5, - "m": map[string]any{ - "n": json.Number("123.456"), - "schema": &Schema{Type: "integer", UniqueItems: true}, - }, - "c": 1.2 + 3.4i, - "n": nil, - } - - seed := maphash.MakeSeed() - - hash := func(x any) uint64 { - var h maphash.Hash - h.SetSeed(seed) - hashValue(&h, reflect.ValueOf(x)) - return h.Sum64() - } - - want := hash(x) - // Run several times to verify consistency. - for range 10 { - if got := hash(x); got != want { - t.Errorf("hash values differ: %d vs. %d", got, want) - } - } - - // Check mathematically equal values. - nums := []any{ - 5, - uint(5), - 5.0, - json.Number("5"), - json.Number("5.00"), - } - for i, n := range nums { - if i == 0 { - want = hash(n) - } else if got := hash(n); got != want { - t.Errorf("hashes differ between %v (%[1]T) and %v (%[2]T)", nums[0], n) - } - } - - // Check that a bare JSON `null` is OK. - var null any - if err := json.Unmarshal([]byte(`null`), &null); err != nil { - t.Fatal(err) - } - _ = hash(null) -} - -func TestMarshalStructWithMap(t *testing.T) { - type S struct { - A int - B string `json:"b,omitempty"` - u bool - M map[string]any `json:"-"` - } - t.Run("basic", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"!@#": true}} - got, err := marshalStructWithMap(&s, "M") - if err != nil { - t.Fatal(err) - } - want := `{"A":1,"b":"two","!@#":true}` - if g := string(got); g != want { - t.Errorf("\ngot %s\nwant %s", g, want) - } - - var un S - if err := unmarshalStructWithMap(got, &un, "M"); err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(s, un, cmpopts.IgnoreUnexported(S{})); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } - }) - t.Run("duplicate", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"b": "dup"}} - _, err := marshalStructWithMap(&s, "M") - if err == nil || !strings.Contains(err.Error(), "duplicate") { - t.Errorf("got %v, want error with 'duplicate'", err) - } - }) - t.Run("embedded", func(t *testing.T) { - type Embedded struct { - A int - B int - Extra map[string]any `json:"-"` - } - type S struct { - C int - Embedded - } - s := S{C: 1, Embedded: Embedded{A: 2, B: 3, Extra: map[string]any{"d": 4, "e": 5}}} - got, err := marshalStructWithMap(&s, "Extra") - if err != nil { - t.Fatal(err) - } - want := `{"C":1,"A":2,"B":3,"d":4,"e":5}` - if g := string(got); g != want { - t.Errorf("got %v, want %v", g, want) - } - }) -} - -func TestJSONInfo(t *testing.T) { - type S struct { - A int - B int `json:","` - C int `json:"-"` - D int `json:"-,"` - E int `json:"echo"` - F int `json:"foxtrot,omitempty"` - g int `json:"golf"` - } - want := []jsonInfo{ - {name: "A"}, - {name: "B"}, - {omit: true}, - {name: "-"}, - {name: "echo"}, - {name: "foxtrot", settings: map[string]bool{"omitempty": true}}, - {omit: true}, - } - tt := reflect.TypeFor[S]() - for i := range tt.NumField() { - got := fieldJSONInfo(tt.Field(i)) - if !reflect.DeepEqual(got, want[i]) { - t.Errorf("got %+v, want %+v", got, want[i]) - } - } -} diff --git a/jsonschema/validate.go b/jsonschema/validate.go deleted file mode 100644 index 99ddd3b8..00000000 --- a/jsonschema/validate.go +++ /dev/null @@ -1,758 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "fmt" - "hash/maphash" - "iter" - "math" - "math/big" - "reflect" - "slices" - "strings" - "sync" - "unicode/utf8" -) - -// The value of the "$schema" keyword for the version that we can validate. -const draft202012 = "https://json-schema.org/draft/2020-12/schema" - -// Validate validates the instance, which must be a JSON value, against the schema. -// It returns nil if validation is successful or an error if it is not. -// If the schema type is "object", instance can be a map[string]any or a struct. -func (rs *Resolved) Validate(instance any) error { - if s := rs.root.Schema; s != "" && s != draft202012 { - return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) - } - st := &state{rs: rs} - return st.validate(reflect.ValueOf(instance), st.rs.root, nil) -} - -// validateDefaults walks the schema tree. If it finds a default, it validates it -// against the schema containing it. -// -// TODO(jba): account for dynamic refs. This algorithm simple-mindedly -// treats each schema with a default as its own root. -func (rs *Resolved) validateDefaults() error { - if s := rs.root.Schema; s != "" && s != draft202012 { - return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) - } - st := &state{rs: rs} - for s := range rs.root.all() { - // We checked for nil schemas in [Schema.Resolve]. - assert(s != nil, "nil schema") - if s.DynamicRef != "" { - return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", rs.schemaString(s)) - } - if s.Default != nil { - var d any - if err := json.Unmarshal(s.Default, &d); err != nil { - return fmt.Errorf("unmarshaling default value of schema %s: %w", rs.schemaString(s), err) - } - if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { - return err - } - } - } - return nil -} - -// state is the state of single call to ResolvedSchema.Validate. -type state struct { - rs *Resolved - // stack holds the schemas from recursive calls to validate. - // These are the "dynamic scopes" used to resolve dynamic references. - // https://json-schema.org/draft/2020-12/json-schema-core#scopes - stack []*Schema -} - -// validate validates the reflected value of the instance. -func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { - defer wrapf(&err, "validating %s", st.rs.schemaString(schema)) - - // Maintain a stack for dynamic schema resolution. - st.stack = append(st.stack, schema) // push - defer func() { - st.stack = st.stack[:len(st.stack)-1] // pop - }() - - // We checked for nil schemas in [Schema.Resolve]. - assert(schema != nil, "nil schema") - - // Step through interfaces and pointers. - for instance.Kind() == reflect.Pointer || instance.Kind() == reflect.Interface { - instance = instance.Elem() - } - - schemaInfo := st.rs.resolvedInfos[schema] - - // type: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1 - if schema.Type != "" || schema.Types != nil { - gotType, ok := jsonType(instance) - if !ok { - return fmt.Errorf("type: %v of type %[1]T is not a valid JSON value", instance) - } - if schema.Type != "" { - // "number" subsumes integers - if !(gotType == schema.Type || - gotType == "integer" && schema.Type == "number") { - return fmt.Errorf("type: %v has type %q, want %q", instance, gotType, schema.Type) - } - } else { - if !(slices.Contains(schema.Types, gotType) || (gotType == "integer" && slices.Contains(schema.Types, "number"))) { - return fmt.Errorf("type: %v has type %q, want one of %q", - instance, gotType, strings.Join(schema.Types, ", ")) - } - } - } - // enum: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.2 - if schema.Enum != nil { - ok := false - for _, e := range schema.Enum { - if equalValue(reflect.ValueOf(e), instance) { - ok = true - break - } - } - if !ok { - return fmt.Errorf("enum: %v does not equal any of: %v", instance, schema.Enum) - } - } - - // const: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.3 - if schema.Const != nil { - if !equalValue(reflect.ValueOf(*schema.Const), instance) { - return fmt.Errorf("const: %v does not equal %v", instance, *schema.Const) - } - } - - // numbers: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.2 - if schema.MultipleOf != nil || schema.Minimum != nil || schema.Maximum != nil || schema.ExclusiveMinimum != nil || schema.ExclusiveMaximum != nil { - n, ok := jsonNumber(instance) - if ok { // these keywords don't apply to non-numbers - if schema.MultipleOf != nil { - // TODO: validate MultipleOf as non-zero. - // The test suite assumes floats. - nf, _ := n.Float64() // don't care if it's exact or not - if _, f := math.Modf(nf / *schema.MultipleOf); f != 0 { - return fmt.Errorf("multipleOf: %s is not a multiple of %f", n, *schema.MultipleOf) - } - } - - m := new(big.Rat) // reuse for all of the following - cmp := func(f float64) int { return n.Cmp(m.SetFloat64(f)) } - - if schema.Minimum != nil && cmp(*schema.Minimum) < 0 { - return fmt.Errorf("minimum: %s is less than %f", n, *schema.Minimum) - } - if schema.Maximum != nil && cmp(*schema.Maximum) > 0 { - return fmt.Errorf("maximum: %s is greater than %f", n, *schema.Maximum) - } - if schema.ExclusiveMinimum != nil && cmp(*schema.ExclusiveMinimum) <= 0 { - return fmt.Errorf("exclusiveMinimum: %s is less than or equal to %f", n, *schema.ExclusiveMinimum) - } - if schema.ExclusiveMaximum != nil && cmp(*schema.ExclusiveMaximum) >= 0 { - return fmt.Errorf("exclusiveMaximum: %s is greater than or equal to %f", n, *schema.ExclusiveMaximum) - } - } - } - - // strings: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.3 - if instance.Kind() == reflect.String && (schema.MinLength != nil || schema.MaxLength != nil || schema.Pattern != "") { - str := instance.String() - n := utf8.RuneCountInString(str) - if schema.MinLength != nil { - if m := *schema.MinLength; n < m { - return fmt.Errorf("minLength: %q contains %d Unicode code points, fewer than %d", str, n, m) - } - } - if schema.MaxLength != nil { - if m := *schema.MaxLength; n > m { - return fmt.Errorf("maxLength: %q contains %d Unicode code points, more than %d", str, n, m) - } - } - - if schema.Pattern != "" && !schemaInfo.pattern.MatchString(str) { - return fmt.Errorf("pattern: %q does not match regular expression %q", str, schema.Pattern) - } - } - - var anns annotations // all the annotations for this call and child calls - - // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 - if schema.Ref != "" { - if err := st.validate(instance, schemaInfo.resolvedRef, &anns); err != nil { - return err - } - } - - // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 - if schema.DynamicRef != "" { - // The ref behaves lexically or dynamically, but not both. - assert((schemaInfo.resolvedDynamicRef == nil) != (schemaInfo.dynamicRefAnchor == ""), - "DynamicRef not resolved properly") - if schemaInfo.resolvedDynamicRef != nil { - // Same as $ref. - if err := st.validate(instance, schemaInfo.resolvedDynamicRef, &anns); err != nil { - return err - } - } else { - // Dynamic behavior. - // Look for the base of the outermost schema on the stack with this dynamic - // anchor. (Yes, outermost: the one farthest from here. This the opposite - // of how ordinary dynamic variables behave.) - // Why the base of the schema being validated and not the schema itself? - // Because the base is the scope for anchors. In fact it's possible to - // refer to a schema that is not on the stack, but a child of some base - // on the stack. - // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. - var dynamicSchema *Schema - for _, s := range st.stack { - base := st.rs.resolvedInfos[s].base - info, ok := st.rs.resolvedInfos[base].anchors[schemaInfo.dynamicRefAnchor] - if ok && info.dynamic { - dynamicSchema = info.schema - break - } - } - if dynamicSchema == nil { - return fmt.Errorf("missing dynamic anchor %q", schemaInfo.dynamicRefAnchor) - } - if err := st.validate(instance, dynamicSchema, &anns); err != nil { - return err - } - } - } - - // logic - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 - // These must happen before arrays and objects because if they evaluate an item or property, - // then the unevaluatedItems/Properties schemas don't apply to it. - // See https://json-schema.org/draft/2020-12/json-schema-core#section-11.2, paragraph 4. - // - // If any of these fail, then validation fails, even if there is an unevaluatedXXX - // keyword in the schema. The spec is unclear about this, but that is the intention. - - valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns) == nil } - - if schema.AllOf != nil { - for _, ss := range schema.AllOf { - if err := st.validate(instance, ss, &anns); err != nil { - return err - } - } - } - if schema.AnyOf != nil { - // We must visit them all, to collect annotations. - ok := false - for _, ss := range schema.AnyOf { - if valid(ss, &anns) { - ok = true - } - } - if !ok { - return fmt.Errorf("anyOf: did not validate against any of %v", schema.AnyOf) - } - } - if schema.OneOf != nil { - // Exactly one. - var okSchema *Schema - for _, ss := range schema.OneOf { - if valid(ss, &anns) { - if okSchema != nil { - return fmt.Errorf("oneOf: validated against both %v and %v", okSchema, ss) - } - okSchema = ss - } - } - if okSchema == nil { - return fmt.Errorf("oneOf: did not validate against any of %v", schema.OneOf) - } - } - if schema.Not != nil { - // Ignore annotations from "not". - if valid(schema.Not, nil) { - return fmt.Errorf("not: validated against %v", schema.Not) - } - } - if schema.If != nil { - var ss *Schema - if valid(schema.If, &anns) { - ss = schema.Then - } else { - ss = schema.Else - } - if ss != nil { - if err := st.validate(instance, ss, &anns); err != nil { - return err - } - } - } - - // arrays - // TODO(jba): consider arrays of structs. - if instance.Kind() == reflect.Array || instance.Kind() == reflect.Slice { - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.1 - // This validate call doesn't collect annotations for the items of the instance; they are separate - // instances in their own right. - // TODO(jba): if the test suite doesn't cover this case, add a test. For example, nested arrays. - for i, ischema := range schema.PrefixItems { - if i >= instance.Len() { - break // shorter is OK - } - if err := st.validate(instance.Index(i), ischema, nil); err != nil { - return err - } - } - anns.noteEndIndex(min(len(schema.PrefixItems), instance.Len())) - - if schema.Items != nil { - for i := len(schema.PrefixItems); i < instance.Len(); i++ { - if err := st.validate(instance.Index(i), schema.Items, nil); err != nil { - return err - } - } - // Note that all the items in this array have been validated. - anns.allItems = true - } - - nContains := 0 - if schema.Contains != nil { - for i := range instance.Len() { - if err := st.validate(instance.Index(i), schema.Contains, nil); err == nil { - nContains++ - anns.noteIndex(i) - } - } - if nContains == 0 && (schema.MinContains == nil || *schema.MinContains > 0) { - return fmt.Errorf("contains: %s does not have an item matching %s", instance, schema.Contains) - } - } - - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.4 - // TODO(jba): check that these next four keywords' values are integers. - if schema.MinContains != nil && schema.Contains != nil { - if m := *schema.MinContains; nContains < m { - return fmt.Errorf("minContains: contains validated %d items, less than %d", nContains, m) - } - } - if schema.MaxContains != nil && schema.Contains != nil { - if m := *schema.MaxContains; nContains > m { - return fmt.Errorf("maxContains: contains validated %d items, greater than %d", nContains, m) - } - } - if schema.MinItems != nil { - if m := *schema.MinItems; instance.Len() < m { - return fmt.Errorf("minItems: array length %d is less than %d", instance.Len(), m) - } - } - if schema.MaxItems != nil { - if m := *schema.MaxItems; instance.Len() > m { - return fmt.Errorf("maxItems: array length %d is greater than %d", instance.Len(), m) - } - } - if schema.UniqueItems { - if instance.Len() > 1 { - // Hash each item and compare the hashes. - // If two hashes differ, the items differ. - // If two hashes are the same, compare the collisions for equality. - // (The same logic as hash table lookup.) - // TODO(jba): Use container/hash.Map when it becomes available (https://go.dev/issue/69559), - hashes := map[uint64][]int{} // from hash to indices - seed := maphash.MakeSeed() - for i := range instance.Len() { - item := instance.Index(i) - var h maphash.Hash - h.SetSeed(seed) - hashValue(&h, item) - hv := h.Sum64() - if sames := hashes[hv]; len(sames) > 0 { - for _, j := range sames { - if equalValue(item, instance.Index(j)) { - return fmt.Errorf("uniqueItems: array items %d and %d are equal", i, j) - } - } - } - hashes[hv] = append(hashes[hv], i) - } - } - } - - // https://json-schema.org/draft/2020-12/json-schema-core#section-11.2 - if schema.UnevaluatedItems != nil && !anns.allItems { - // Apply this subschema to all items in the array that haven't been successfully validated. - // That includes validations by subschemas on the same instance, like allOf. - for i := anns.endIndex; i < instance.Len(); i++ { - if !anns.evaluatedIndexes[i] { - if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil); err != nil { - return err - } - } - } - anns.allItems = true - } - } - - // objects - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.2 - if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { - if instance.Kind() == reflect.Map { - if kt := instance.Type().Key(); kt.Kind() != reflect.String { - return fmt.Errorf("map key type %s is not a string", kt) - } - } - // Track the evaluated properties for just this schema, to support additionalProperties. - // If we used anns here, then we'd be including properties evaluated in subschemas - // from allOf, etc., which additionalProperties shouldn't observe. - evalProps := map[string]bool{} - for prop, subschema := range schema.Properties { - val := property(instance, prop) - if !val.IsValid() { - // It's OK if the instance doesn't have the property. - continue - } - // If the instance is a struct and an optional property has the zero - // value, then we could interpret it as present or missing. Be generous: - // assume it's missing, and thus always validates successfully. - if instance.Kind() == reflect.Struct && val.IsZero() && !schemaInfo.isRequired[prop] { - continue - } - if err := st.validate(val, subschema, nil); err != nil { - return err - } - evalProps[prop] = true - } - if len(schema.PatternProperties) > 0 { - for prop, val := range properties(instance) { - // Check every matching pattern. - for re, schema := range schemaInfo.patternProperties { - if re.MatchString(prop) { - if err := st.validate(val, schema, nil); err != nil { - return err - } - evalProps[prop] = true - } - } - } - } - if schema.AdditionalProperties != nil { - // Apply to all properties not handled above. - for prop, val := range properties(instance) { - if !evalProps[prop] { - if err := st.validate(val, schema.AdditionalProperties, nil); err != nil { - return err - } - evalProps[prop] = true - } - } - } - anns.noteProperties(evalProps) - if schema.PropertyNames != nil { - // Note: properties unnecessarily fetches each value. We could define a propertyNames function - // if performance ever matters. - for prop := range properties(instance) { - if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil); err != nil { - return err - } - } - } - - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 - var min, max int - if schema.MinProperties != nil || schema.MaxProperties != nil { - min, max = numPropertiesBounds(instance, schemaInfo.isRequired) - } - if schema.MinProperties != nil { - if n, m := max, *schema.MinProperties; n < m { - return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) - } - } - if schema.MaxProperties != nil { - if n, m := min, *schema.MaxProperties; n > m { - return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) - } - } - - hasProperty := func(prop string) bool { - return property(instance, prop).IsValid() - } - - missingProperties := func(props []string) []string { - var missing []string - for _, p := range props { - if !hasProperty(p) { - missing = append(missing, p) - } - } - return missing - } - - if schema.Required != nil { - if m := missingProperties(schema.Required); len(m) > 0 { - return fmt.Errorf("required: missing properties: %q", m) - } - } - if schema.DependentRequired != nil { - // "Validation succeeds if, for each name that appears in both the instance - // and as a name within this keyword's value, every item in the corresponding - // array is also the name of a property in the instance." §6.5.4 - for dprop, reqs := range schema.DependentRequired { - if hasProperty(dprop) { - if m := missingProperties(reqs); len(m) > 0 { - return fmt.Errorf("dependentRequired[%q]: missing properties %q", dprop, m) - } - } - } - } - - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2.2.4 - if schema.DependentSchemas != nil { - // This does not collect annotations, although it seems like it should. - for dprop, ss := range schema.DependentSchemas { - if hasProperty(dprop) { - // TODO: include dependentSchemas[dprop] in the errors. - err := st.validate(instance, ss, &anns) - if err != nil { - return err - } - } - } - } - if schema.UnevaluatedProperties != nil && !anns.allProperties { - // This looks a lot like AdditionalProperties, but depends on in-place keywords like allOf - // in addition to sibling keywords. - for prop, val := range properties(instance) { - if !anns.evaluatedProperties[prop] { - if err := st.validate(val, schema.UnevaluatedProperties, nil); err != nil { - return err - } - } - } - // The spec says the annotation should be the set of evaluated properties, but we can optimize - // by setting a single boolean, since after this succeeds all properties will be validated. - // See https://json-schema.slack.com/archives/CT7FF623C/p1745592564381459. - anns.allProperties = true - } - } - - if callerAnns != nil { - // Our caller wants to know what we've validated. - callerAnns.merge(&anns) - } - return nil -} - -// resolveDynamicRef returns the schema referred to by the argument schema's -// $dynamicRef value. -// It returns an error if the dynamic reference has no referent. -// If there is no $dynamicRef, resolveDynamicRef returns nil, nil. -// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2. -func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { - if schema.DynamicRef == "" { - return nil, nil - } - info := st.rs.resolvedInfos[schema] - // The ref behaves lexically or dynamically, but not both. - assert((info.resolvedDynamicRef == nil) != (info.dynamicRefAnchor == ""), - "DynamicRef not statically resolved properly") - if r := info.resolvedDynamicRef; r != nil { - // Same as $ref. - return r, nil - } - // Dynamic behavior. - // Look for the base of the outermost schema on the stack with this dynamic - // anchor. (Yes, outermost: the one farthest from here. This the opposite - // of how ordinary dynamic variables behave.) - // Why the base of the schema being validated and not the schema itself? - // Because the base is the scope for anchors. In fact it's possible to - // refer to a schema that is not on the stack, but a child of some base - // on the stack. - // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. - for _, s := range st.stack { - base := st.rs.resolvedInfos[s].base - info, ok := st.rs.resolvedInfos[base].anchors[info.dynamicRefAnchor] - if ok && info.dynamic { - return info.schema, nil - } - } - return nil, fmt.Errorf("missing dynamic anchor %q", info.dynamicRefAnchor) -} - -// ApplyDefaults modifies an instance by applying the schema's defaults to it. If -// a schema or sub-schema has a default, then a corresponding zero instance value -// is set to the default. -// -// The JSON Schema specification does not describe how defaults should be interpreted. -// This method honors defaults only on properties, and only those that are not required. -// If the instance is a map and the property is missing, the property is added to -// the map with the default. -// If the instance is a struct, the field corresponding to the property exists, and -// its value is zero, the field is set to the default. -// ApplyDefaults can panic if a default cannot be assigned to a field. -// -// The argument must be a pointer to the instance. -// (In case we decide that top-level defaults are meaningful.) -// -// It is recommended to first call Resolve with a ValidateDefaults option of true, -// then call this method, and lastly call Validate. -// -// TODO(jba): consider what defaults on top-level or array instances might mean. -// TODO(jba): follow $ref and $dynamicRef -// TODO(jba): apply defaults on sub-schemas to corresponding sub-instances. -func (rs *Resolved) ApplyDefaults(instancep any) error { - st := &state{rs: rs} - return st.applyDefaults(reflect.ValueOf(instancep), rs.root) -} - -// Leave this as a potentially recursive helper function, because we'll surely want -// to apply defaults on sub-schemas someday. -func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { - defer wrapf(&err, "applyDefaults: schema %s, instance %v", st.rs.schemaString(schema), instancep) - - schemaInfo := st.rs.resolvedInfos[schema] - instance := instancep.Elem() - if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { - if instance.Kind() == reflect.Map { - if kt := instance.Type().Key(); kt.Kind() != reflect.String { - return fmt.Errorf("map key type %s is not a string", kt) - } - } - for prop, subschema := range schema.Properties { - // Ignore defaults on required properties. (A required property shouldn't have a default.) - if schemaInfo.isRequired[prop] { - continue - } - val := property(instance, prop) - switch instance.Kind() { - case reflect.Map: - // If there is a default for this property, and the map key is missing, - // set the map value to the default. - if subschema.Default != nil && !val.IsValid() { - // Create an lvalue, since map values aren't addressable. - lvalue := reflect.New(instance.Type().Elem()) - if err := json.Unmarshal(subschema.Default, lvalue.Interface()); err != nil { - return err - } - instance.SetMapIndex(reflect.ValueOf(prop), lvalue.Elem()) - } - case reflect.Struct: - // If there is a default for this property, and the field exists but is zero, - // set the field to the default. - if subschema.Default != nil && val.IsValid() && val.IsZero() { - if err := json.Unmarshal(subschema.Default, val.Addr().Interface()); err != nil { - return err - } - } - default: - panic(fmt.Sprintf("applyDefaults: property %s: bad value %s of kind %s", - prop, instance, instance.Kind())) - } - } - } - return nil -} - -// property returns the value of the property of v with the given name, or the invalid -// reflect.Value if there is none. -// If v is a map, the property is the value of the map whose key is name. -// If v is a struct, the property is the value of the field with the given name according -// to the encoding/json package (see [jsonName]). -// If v is anything else, property panics. -func property(v reflect.Value, name string) reflect.Value { - switch v.Kind() { - case reflect.Map: - return v.MapIndex(reflect.ValueOf(name)) - case reflect.Struct: - props := structPropertiesOf(v.Type()) - // Ignore nonexistent properties. - if sf, ok := props[name]; ok { - return v.FieldByIndex(sf.Index) - } - return reflect.Value{} - default: - panic(fmt.Sprintf("property(%q): bad value %s of kind %s", name, v, v.Kind())) - } -} - -// properties returns an iterator over the names and values of all properties -// in v, which must be a map or a struct. -// If a struct, zero-valued properties that are marked omitempty or omitzero -// are excluded. -func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { - return func(yield func(string, reflect.Value) bool) { - switch v.Kind() { - case reflect.Map: - for k, e := range v.Seq2() { - if !yield(k.String(), e) { - return - } - } - case reflect.Struct: - for name, sf := range structPropertiesOf(v.Type()) { - val := v.FieldByIndex(sf.Index) - if val.IsZero() { - info := fieldJSONInfo(sf) - if info.settings["omitempty"] || info.settings["omitzero"] { - continue - } - } - if !yield(name, val) { - return - } - } - default: - panic(fmt.Sprintf("bad value %s of kind %s", v, v.Kind())) - } - } -} - -// numPropertiesBounds returns bounds on the number of v's properties. -// v must be a map or a struct. -// If v is a map, both bounds are the map's size. -// If v is a struct, the max is the number of struct properties. -// But since we don't know whether a zero value indicates a missing optional property -// or not, be generous and use the number of non-zero properties as the min. -func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) { - switch v.Kind() { - case reflect.Map: - return v.Len(), v.Len() - case reflect.Struct: - sp := structPropertiesOf(v.Type()) - min := 0 - for prop, sf := range sp { - if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] { - min++ - } - } - return min, len(sp) - default: - panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) - } -} - -// A propertyMap is a map from property name to struct field index. -type propertyMap = map[string]reflect.StructField - -var structProperties sync.Map // from reflect.Type to propertyMap - -// structPropertiesOf returns the JSON Schema properties for the struct type t. -// The caller must not mutate the result. -func structPropertiesOf(t reflect.Type) propertyMap { - // Mutex not necessary: at worst we'll recompute the same value. - if props, ok := structProperties.Load(t); ok { - return props.(propertyMap) - } - props := map[string]reflect.StructField{} - for _, sf := range reflect.VisibleFields(t) { - info := fieldJSONInfo(sf) - if !info.omit { - props[info.name] = sf - } - } - structProperties.Store(t, props) - return props -} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go deleted file mode 100644 index 7fb52d18..00000000 --- a/jsonschema/validate_test.go +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "fmt" - "net/url" - "os" - "path/filepath" - "reflect" - "strings" - "testing" -) - -// The test for validation uses the official test suite, expressed as a set of JSON files. -// Each file is an array of group objects. - -// A testGroup consists of a schema and some tests on it. -type testGroup struct { - Description string - Schema *Schema - Tests []test -} - -// A test consists of a JSON instance to be validated and the expected result. -type test struct { - Description string - Data any - Valid bool -} - -func TestValidate(t *testing.T) { - files, err := filepath.Glob(filepath.FromSlash("testdata/draft2020-12/*.json")) - if err != nil { - t.Fatal(err) - } - if len(files) == 0 { - t.Fatal("no files") - } - for _, file := range files { - base := filepath.Base(file) - t.Run(base, func(t *testing.T) { - data, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - var groups []testGroup - if err := json.Unmarshal(data, &groups); err != nil { - t.Fatal(err) - } - for _, g := range groups { - t.Run(g.Description, func(t *testing.T) { - rs, err := g.Schema.Resolve(&ResolveOptions{Loader: loadRemote}) - if err != nil { - t.Fatal(err) - } - for _, test := range g.Tests { - t.Run(test.Description, func(t *testing.T) { - err = rs.Validate(test.Data) - if err != nil && test.Valid { - t.Errorf("wanted success, but failed with: %v", err) - } - if err == nil && !test.Valid { - t.Error("succeeded but wanted failure") - } - if t.Failed() { - t.Errorf("schema: %s", g.Schema.json()) - t.Fatalf("instance: %v (%[1]T)", test.Data) - } - }) - } - }) - } - }) - } -} - -func TestValidateErrors(t *testing.T) { - schema := &Schema{ - PrefixItems: []*Schema{{Contains: &Schema{Type: "integer"}}}, - } - rs, err := schema.Resolve(nil) - if err != nil { - t.Fatal(err) - } - err = rs.Validate([]any{[]any{"1"}}) - want := "prefixItems/0" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("error:\n%s\ndoes not contain %q", err, want) - } -} - -func TestValidateDefaults(t *testing.T) { - s := &Schema{ - Properties: map[string]*Schema{ - "a": {Type: "integer", Default: mustMarshal(1)}, - "b": {Type: "string", Default: mustMarshal("s")}, - }, - Default: mustMarshal(map[string]any{"a": 1, "b": "two"}), - } - if _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}); err != nil { - t.Fatal(err) - } - - s = &Schema{ - Properties: map[string]*Schema{ - "a": {Type: "integer", Default: mustMarshal(3)}, - "b": {Type: "string", Default: mustMarshal("s")}, - }, - Default: mustMarshal(map[string]any{"a": 1, "b": 2}), - } - _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}) - want := `has type "integer", want "string"` - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Resolve returned error %q, want %q", err, want) - } -} - -func TestApplyDefaults(t *testing.T) { - schema := &Schema{ - Properties: map[string]*Schema{ - "A": {Default: mustMarshal(1)}, - "B": {Default: mustMarshal(2)}, - "C": {Default: mustMarshal(3)}, - }, - Required: []string{"C"}, - } - rs, err := schema.Resolve(&ResolveOptions{ValidateDefaults: true}) - if err != nil { - t.Fatal(err) - } - - type S struct{ A, B, C int } - for _, tt := range []struct { - instancep any // pointer to instance value - want any // desired value (not a pointer) - }{ - { - &map[string]any{"B": 0}, - map[string]any{ - "A": float64(1), // filled from default - "B": 0, // untouched: it was already there - // "C" not added: it is required (Validate will catch that) - }, - }, - { - &S{B: 1}, - S{ - A: 1, // filled from default - B: 1, // untouched: non-zero - C: 0, // untouched: required - }, - }, - } { - if err := rs.ApplyDefaults(tt.instancep); err != nil { - t.Fatal(err) - } - got := reflect.ValueOf(tt.instancep).Elem().Interface() // dereference the pointer - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("\ngot %#v\nwant %#v", got, tt.want) - } - } -} - -func TestStructInstance(t *testing.T) { - instance := struct { - I int - B bool `json:"b"` - P *int // either missing or nil - u int // unexported: not a property - }{1, true, nil, 0} - - for _, tt := range []struct { - s Schema - want bool - }{ - { - Schema{MinProperties: Ptr(4)}, - false, - }, - { - Schema{MinProperties: Ptr(3)}, - true, // P interpreted as present - }, - { - Schema{MaxProperties: Ptr(1)}, - false, - }, - { - Schema{MaxProperties: Ptr(2)}, - true, // P interpreted as absent - }, - { - Schema{Required: []string{"i"}}, // the name is "I" - false, - }, - { - Schema{Required: []string{"B"}}, // the name is "b" - false, - }, - { - Schema{PropertyNames: &Schema{MinLength: Ptr(2)}}, - false, - }, - { - Schema{Properties: map[string]*Schema{"b": {Type: "boolean"}}}, - true, - }, - { - Schema{Properties: map[string]*Schema{"b": {Type: "number"}}}, - false, - }, - { - Schema{Required: []string{"I"}}, - true, - }, - { - Schema{Required: []string{"I", "P"}}, - true, // P interpreted as present - }, - { - Schema{Required: []string{"I", "P"}, Properties: map[string]*Schema{"P": {Type: "number"}}}, - false, // P interpreted as present, but not a number - }, - { - Schema{Required: []string{"I"}, Properties: map[string]*Schema{"P": {Type: "number"}}}, - true, // P not required, so interpreted as absent - }, - { - Schema{Required: []string{"I"}, AdditionalProperties: falseSchema()}, - false, - }, - { - Schema{DependentRequired: map[string][]string{"b": {"u"}}}, - false, - }, - { - Schema{DependentSchemas: map[string]*Schema{"b": falseSchema()}}, - false, - }, - { - Schema{UnevaluatedProperties: falseSchema()}, - false, - }, - } { - res, err := tt.s.Resolve(nil) - if err != nil { - t.Fatal(err) - } - err = res.Validate(instance) - if err == nil && !tt.want { - t.Errorf("succeeded unexpectedly\nschema = %s", tt.s.json()) - } else if err != nil && tt.want { - t.Errorf("Validate: %v\nschema = %s", err, tt.s.json()) - } - } -} - -func mustMarshal(x any) json.RawMessage { - data, err := json.Marshal(x) - if err != nil { - panic(err) - } - return json.RawMessage(data) -} - -// loadRemote loads a remote reference used in the test suite. -func loadRemote(uri *url.URL) (*Schema, error) { - // Anything with localhost:1234 refers to the remotes directory in the test suite repo. - if uri.Host == "localhost:1234" { - return loadSchemaFromFile(filepath.FromSlash(filepath.Join("testdata/remotes", uri.Path))) - } - // One test needs the meta-schema files. - const metaPrefix = "https://json-schema.org/draft/2020-12/" - if after, ok := strings.CutPrefix(uri.String(), metaPrefix); ok { - return loadSchemaFromFile(filepath.FromSlash("meta-schemas/draft2020-12/" + after + ".json")) - } - return nil, fmt.Errorf("don't know how to load %s", uri) -} - -func loadSchemaFromFile(filename string) (*Schema, error) { - data, err := os.ReadFile(filename) - if err != nil { - return nil, err - } - var s Schema - if err := json.Unmarshal(data, &s); err != nil { - return nil, fmt.Errorf("unmarshaling JSON at %s: %w", filename, err) - } - return &s, nil -} diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 497a9cd0..5b13a4c8 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -11,7 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) diff --git a/mcp/client_test.go b/mcp/client_test.go index 73fe09e6..7920c55c 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -11,7 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) type Item struct { diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 1328473a..597b9dcd 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -11,7 +11,7 @@ import ( "os" "time" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) diff --git a/mcp/features_test.go b/mcp/features_test.go index e0165ecb..1c22ecd3 100644 --- a/mcp/features_test.go +++ b/mcp/features_test.go @@ -11,7 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) type SayHiParams struct { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index da53465c..48e95de2 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -21,8 +21,8 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" - "github.com/modelcontextprotocol/go-sdk/jsonschema" ) type hiParams struct { diff --git a/mcp/protocol.go b/mcp/protocol.go index 3ca6cb5e..d80a0787 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -14,7 +14,7 @@ import ( "encoding/json" "fmt" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) // Optional annotations for the client. The client can use annotations to inform diff --git a/mcp/server_test.go b/mcp/server_test.go index 5a161b72..6415decc 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,7 +11,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) type testItem struct { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index af06bd39..849f6026 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -24,9 +24,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" - "github.com/modelcontextprotocol/go-sdk/jsonschema" ) func TestStreamableTransports(t *testing.T) { diff --git a/mcp/tool.go b/mcp/tool.go index ed80b660..234cd659 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -11,7 +11,7 @@ import ( "fmt" "reflect" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) // A ToolHandler handles a call to tools/call. diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 4d0a329b..52cac9fc 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -12,7 +12,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) // testToolHandler is used for type inference in TestNewServerTool. From 8186bf39c76ec971d779a705bcc42781c6ea7d99 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 7 Aug 2025 19:22:43 +0000 Subject: [PATCH 079/221] mcp: lock down Params and Result Add unexported methods to the Params and Result interface, so that they're harder to implement outside the mcp package. It looks like these are the only two interfaces we need to lock down: others are either intentionally open (Transport, Connection), or already closed (Session). Fixes #263 --- mcp/protocol.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ mcp/shared.go | 7 +++++++ 2 files changed, 53 insertions(+) diff --git a/mcp/protocol.go b/mcp/protocol.go index d80a0787..a2c0843e 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -78,6 +78,8 @@ type CallToolResultFor[Out any] struct { IsError bool `json:"isError,omitempty"` } +func (*CallToolResultFor[Out]) mcpResult() {} + // UnmarshalJSON handles the unmarshalling of content into the Content // interface. func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { @@ -97,6 +99,7 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { return nil } +func (x *CallToolParamsFor[Out]) mcpParams() {} func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) } @@ -114,6 +117,7 @@ type CancelledParams struct { RequestID any `json:"requestId"` } +func (x *CancelledParams) mcpParams() {} func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -207,6 +211,8 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } +func (*CompleteParams) mcpParams() {} + type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` Total int `json:"total,omitempty"` @@ -221,6 +227,8 @@ type CompleteResult struct { Completion CompletionResultDetails `json:"completion"` } +func (*CompleteResult) mcpResult() {} + type CreateMessageParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -245,6 +253,7 @@ type CreateMessageParams struct { Temperature float64 `json:"temperature,omitempty"` } +func (x *CreateMessageParams) mcpParams() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -264,6 +273,7 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } +func (*CreateMessageResult) mcpResult() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -291,6 +301,7 @@ type GetPromptParams struct { Name string `json:"name"` } +func (x *GetPromptParams) mcpParams() {} func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -304,6 +315,8 @@ type GetPromptResult struct { Messages []*PromptMessage `json:"messages"` } +func (*GetPromptResult) mcpResult() {} + type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -315,6 +328,7 @@ type InitializeParams struct { ProtocolVersion string `json:"protocolVersion"` } +func (x *InitializeParams) mcpParams() {} func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -338,12 +352,15 @@ type InitializeResult struct { ServerInfo *Implementation `json:"serverInfo"` } +func (*InitializeResult) mcpResult() {} + type InitializedParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` } +func (x *InitializedParams) mcpParams() {} func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -356,6 +373,7 @@ type ListPromptsParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListPromptsParams) mcpParams() {} func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -371,6 +389,7 @@ type ListPromptsResult struct { Prompts []*Prompt `json:"prompts"` } +func (x *ListPromptsResult) mcpResult() {} func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourceTemplatesParams struct { @@ -382,6 +401,7 @@ type ListResourceTemplatesParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListResourceTemplatesParams) mcpParams() {} func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -397,6 +417,7 @@ type ListResourceTemplatesResult struct { ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` } +func (x *ListResourceTemplatesResult) mcpResult() {} func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourcesParams struct { @@ -408,6 +429,7 @@ type ListResourcesParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListResourcesParams) mcpParams() {} func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -423,6 +445,7 @@ type ListResourcesResult struct { Resources []*Resource `json:"resources"` } +func (x *ListResourcesResult) mcpResult() {} func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } type ListRootsParams struct { @@ -431,6 +454,7 @@ type ListRootsParams struct { Meta `json:"_meta,omitempty"` } +func (x *ListRootsParams) mcpParams() {} func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -444,6 +468,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } +func (*ListRootsResult) mcpResult() {} + type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -453,6 +479,7 @@ type ListToolsParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListToolsParams) mcpParams() {} func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -468,6 +495,7 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } +func (x *ListToolsResult) mcpResult() {} func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } // The severity of a log message. @@ -489,6 +517,7 @@ type LoggingMessageParams struct { Logger string `json:"logger,omitempty"` } +func (x *LoggingMessageParams) mcpParams() {} func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -550,6 +579,7 @@ type PingParams struct { Meta `json:"_meta,omitempty"` } +func (x *PingParams) mcpParams() {} func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -569,6 +599,8 @@ type ProgressNotificationParams struct { Total float64 `json:"total,omitempty"` } +func (*ProgressNotificationParams) mcpParams() {} + // A prompt or prompt template that the server offers. type Prompt struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -606,6 +638,7 @@ type PromptListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *PromptListChangedParams) mcpParams() {} func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -646,6 +679,7 @@ type ReadResourceParams struct { URI string `json:"uri"` } +func (x *ReadResourceParams) mcpParams() {} func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -657,6 +691,8 @@ type ReadResourceResult struct { Contents []*ResourceContents `json:"contents"` } +func (*ReadResourceResult) mcpResult() {} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -697,6 +733,7 @@ type ResourceListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *ResourceListChangedParams) mcpParams() {} func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -754,6 +791,7 @@ type RootsListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *RootsListChangedParams) mcpParams() {} func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -798,6 +836,7 @@ type SetLevelParams struct { Level LoggingLevel `json:"level"` } +func (x *SetLevelParams) mcpParams() {} func (x *SetLevelParams) GetProgressToken() any { return getProgressToken(x) } func (x *SetLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -873,6 +912,7 @@ type ToolListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *ToolListChangedParams) mcpParams() {} func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -886,6 +926,8 @@ type SubscribeParams struct { URI string `json:"uri"` } +func (*SubscribeParams) mcpParams() {} + // Sent from the client to request cancellation of resources/updated // notifications from the server. This should follow a previous // resources/subscribe request. @@ -897,6 +939,8 @@ type UnsubscribeParams struct { URI string `json:"uri"` } +func (*UnsubscribeParams) mcpParams() {} + // A notification from the server to the client, informing it that a resource // has changed and may need to be read again. This should only be sent if the // client previously sent a resources/subscribe request. @@ -908,6 +952,8 @@ type ResourceUpdatedNotificationParams struct { URI string `json:"uri"` } +func (*ResourceUpdatedNotificationParams) mcpParams() {} + // TODO(jba): add CompleteRequest and related types. // TODO(jba): add ElicitRequest and related types. diff --git a/mcp/shared.go b/mcp/shared.go index 319071f2..8d0ceb1c 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -335,6 +335,9 @@ func setProgressToken(p Params, pt any) { // Params is a parameter (input) type for an MCP call or notification. type Params interface { + // mcpParams discourages implementation of Params outside of this package. + mcpParams() + // GetMeta returns metadata from a value. GetMeta() map[string]any // SetMeta sets the metadata on a value. @@ -356,6 +359,9 @@ type RequestParams interface { // Result is a result of an MCP call. type Result interface { + // mcpResult discourages implementation of Result outside of this package. + mcpResult() + // GetMeta returns metadata from a value. GetMeta() map[string]any // SetMeta sets the metadata on a value. @@ -366,6 +372,7 @@ type Result interface { // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} +func (*emptyResult) mcpResult() {} func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") } From e2cf95ca223956bd7cc6179b64e5d26bd1c1bc1f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 8 Aug 2025 11:07:03 -0400 Subject: [PATCH 080/221] auth: add OAuth authenticating middleware for server (#261) Add auth.RequireBearerToken and associated types. This piece of middleware authenticates clients using OAuth 2.0, as specified in the MCP spec. For #237. Usage: ``` st := mcp.NewStreamableServerTransport(...) http.Handle(path, auth.RequireBearerToken(verifier, nil)(st.ServeHTTP)) ``` --- auth/auth.go | 96 +++++++++++++++++++++++++++++++++++++++++++++++ auth/auth_test.go | 70 ++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 auth/auth.go create mode 100644 auth/auth_test.go diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 00000000..68873b48 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,96 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "errors" + "net/http" + "slices" + "strings" + "time" +) + +type TokenInfo struct { + Scopes []string + Expiration time.Time +} + +type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) + +type RequireBearerTokenOptions struct { + Scopes []string + ResourceMetadataURL string +} + +var ErrInvalidToken = errors.New("invalid token") + +type tokenInfoKey struct{} + +// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. +// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. +// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header +// is populated to enable [protected resource metadata]. +// +// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 +func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { + // Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts. + + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenInfo, errmsg, code := verify(r.Context(), verifier, opts, r.Header.Get("Authorization")) + if code != 0 { + if code == http.StatusUnauthorized || code == http.StatusForbidden { + if opts != nil && opts.ResourceMetadataURL != "" { + w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL) + } + } + http.Error(w, errmsg, code) + return + } + r = r.WithContext(context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo)) + handler.ServeHTTP(w, r) + }) + } +} + +func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerTokenOptions, authHeader string) (_ *TokenInfo, errmsg string, code int) { + // Extract bearer token. + fields := strings.Fields(authHeader) + if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { + return nil, "no bearer token", http.StatusUnauthorized + } + + // Verify the token and get information from it. + tokenInfo, err := verifier(ctx, fields[1]) + if err != nil { + if errors.Is(err, ErrInvalidToken) { + return nil, err.Error(), http.StatusUnauthorized + } + // TODO: the TS SDK distinguishes another error, OAuthError, and returns a 400. + // Investigate how that works. + // See typescript-sdk/src/server/auth/middleware/bearerAuth.ts. + return nil, err.Error(), http.StatusInternalServerError + } + + // Check scopes. + if opts != nil { + // Note: quadratic, but N is small. + for _, s := range opts.Scopes { + if !slices.Contains(tokenInfo.Scopes, s) { + return nil, "insufficient scope", http.StatusForbidden + } + } + } + + // Check expiration. + if tokenInfo.Expiration.IsZero() { + return nil, "token missing expiration", http.StatusUnauthorized + } + if tokenInfo.Expiration.Before(time.Now()) { + return nil, "token expired", http.StatusUnauthorized + } + return tokenInfo, "", 0 +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 00000000..715b9bba --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,70 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestVerify(t *testing.T) { + ctx := context.Background() + verifier := func(_ context.Context, token string) (*TokenInfo, error) { + switch token { + case "valid": + return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + case "invalid": + return nil, ErrInvalidToken + case "noexp": + return &TokenInfo{}, nil + case "expired": + return &TokenInfo{Expiration: time.Now().Add(-time.Hour)}, nil + default: + return nil, errors.New("unknown") + } + } + + for _, tt := range []struct { + name string + opts *RequireBearerTokenOptions + header string + wantMsg string + wantCode int + }{ + { + "valid", nil, "Bearer valid", + "", 0, + }, + { + "bad header", nil, "Barer valid", + "no bearer token", 401, + }, + { + "invalid", nil, "bearer invalid", + "invalid token", 401, + }, + { + "no expiration", nil, "Bearer noexp", + "token missing expiration", 401, + }, + { + "expired", nil, "Bearer expired", + "token expired", 401, + }, + { + "missing scope", &RequireBearerTokenOptions{Scopes: []string{"s1"}}, "Bearer valid", + "insufficient scope", 403, + }, + } { + t.Run(tt.name, func(t *testing.T) { + _, gotMsg, gotCode := verify(ctx, verifier, tt.opts, tt.header) + if gotMsg != tt.wantMsg || gotCode != tt.wantCode { + t.Errorf("got (%q, %d), want (%q, %d)", gotMsg, gotCode, tt.wantMsg, tt.wantCode) + } + }) + } +} From c03cd684abdc73384db023d0003b47d0ac1c2deb Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 8 Aug 2025 14:32:17 +0000 Subject: [PATCH 081/221] internal/jsonrpc2: allow asynchronous writes This CL introduces a fundamental change to the jsonrpc2 library: connection writes, which were previously serialized by the jsonrpc2 library itself, are now allowed to be concurrent. The change in semantics of the jsonrpc2 library should hopefully be easy to review, since moving the synchronization to the Writer implementation is equivalent to the previous logic. However, this change is critical for the streamable client transport, because it allows for concurrent http requests to the server. Consider that a write is a POST to the server, and we don't know that write succeeded until we get the response header. Previously, we had the following problem: if the client POSTs a request, and the server blocks its response on a request made through the hanging GET, the client was unable to respond because the initial POST is still blocked. We could update our streamable server transport to force a flush of the response headers, but we can't guarantee that other servers behave the same way. Fundamentally, writes in the spec are asynchronous, and we need to support that. --- internal/jsonrpc2/conn.go | 13 ++++--------- internal/jsonrpc2/frame.go | 24 ++++++++++++++++++++---- mcp/sse.go | 10 +++++----- mcp/transport.go | 24 ++++++++++++++++++++++-- 4 files changed, 51 insertions(+), 20 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index fbe0688b..6f48c9ba 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -65,8 +65,7 @@ type Connection struct { state inFlightState // accessed only in updateInFlight done chan struct{} // closed (under stateMu) when state.closed is true and all goroutines have completed - writer chan Writer // 1-buffered; stores the writer when not in use - + writer Writer handler Handler onInternalError func(error) @@ -214,12 +213,11 @@ func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { c := &Connection{ state: inFlightState{closer: cfg.Closer}, done: make(chan struct{}), - writer: make(chan Writer, 1), + writer: cfg.Writer, onDone: cfg.OnDone, onInternalError: cfg.OnInternalError, } c.handler = cfg.Bind(c) - c.writer <- cfg.Writer c.start(ctx, cfg.Reader, cfg.Preempter) return c } @@ -239,7 +237,6 @@ func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Bind c := &Connection{ state: inFlightState{closer: rwc}, done: make(chan struct{}), - writer: make(chan Writer, 1), onDone: onDone, } // It's tempting to set a finalizer on c to verify that the state has gone @@ -259,7 +256,7 @@ func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Bind } c.onInternalError = options.OnInternalError - c.writer <- framer.Writer(rwc) + c.writer = framer.Writer(rwc) reader := framer.Reader(rwc) c.start(ctx, reader, options.Preempter) return c @@ -728,9 +725,7 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e // write is used by all things that write outgoing messages, including replies. // it makes sure that writes are atomic func (c *Connection) write(ctx context.Context, msg Message) error { - writer := <-c.writer - defer func() { c.writer <- writer }() - err := writer.Write(ctx, msg) + err := c.writer.Write(ctx, msg) if err != nil && ctx.Err() == nil { // The call to Write failed, and since ctx.Err() is nil we can't attribute diff --git a/internal/jsonrpc2/frame.go b/internal/jsonrpc2/frame.go index d5bbe75b..46fcc9db 100644 --- a/internal/jsonrpc2/frame.go +++ b/internal/jsonrpc2/frame.go @@ -12,12 +12,14 @@ import ( "io" "strconv" "strings" + "sync" ) // Reader abstracts the transport mechanics from the JSON RPC protocol. // A Conn reads messages from the reader it was provided on construction, // and assumes that each call to Read fully transfers a single message, // or returns an error. +// // A reader is not safe for concurrent use, it is expected it will be used by // a single Conn in a safe manner. type Reader interface { @@ -29,8 +31,9 @@ type Reader interface { // A Conn writes messages using the writer it was provided on construction, // and assumes that each call to Write fully transfers a single message, // or returns an error. -// A writer is not safe for concurrent use, it is expected it will be used by -// a single Conn in a safe manner. +// +// A writer must be safe for concurrent use, as writes may occur concurrently +// in practice: libraries may make calls or respond to requests asynchronously. type Writer interface { // Write sends a message to the stream. Write(context.Context, Message) error @@ -62,7 +65,10 @@ func RawFramer() Framer { return rawFramer{} } type rawFramer struct{} type rawReader struct{ in *json.Decoder } -type rawWriter struct{ out io.Writer } +type rawWriter struct { + mu sync.Mutex + out io.Writer +} func (rawFramer) Reader(rw io.Reader) Reader { return &rawReader{in: json.NewDecoder(rw)} @@ -92,10 +98,14 @@ func (w *rawWriter) Write(ctx context.Context, msg Message) error { return ctx.Err() default: } + data, err := EncodeMessage(msg) if err != nil { return fmt.Errorf("marshaling message: %v", err) } + + w.mu.Lock() + defer w.mu.Unlock() _, err = w.out.Write(data) return err } @@ -107,7 +117,10 @@ func HeaderFramer() Framer { return headerFramer{} } type headerFramer struct{} type headerReader struct{ in *bufio.Reader } -type headerWriter struct{ out io.Writer } +type headerWriter struct { + mu sync.Mutex + out io.Writer +} func (headerFramer) Reader(rw io.Reader) Reader { return &headerReader{in: bufio.NewReader(rw)} @@ -180,6 +193,9 @@ func (w *headerWriter) Write(ctx context.Context, msg Message) error { return ctx.Err() default: } + w.mu.Lock() + defer w.mu.Unlock() + data, err := EncodeMessage(msg) if err != nil { return fmt.Errorf("marshaling message: %v", err) diff --git a/mcp/sse.go b/mcp/sse.go index cf44276b..bdc4770b 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -155,7 +155,7 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { if err != nil { return nil, err } - return sseServerConn{t}, nil + return &sseServerConn{t: t}, nil } func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -244,10 +244,10 @@ type sseServerConn struct { } // TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) -func (s sseServerConn) SessionID() string { return "" } +func (s *sseServerConn) SessionID() string { return "" } // Read implements jsonrpc2.Reader. -func (s sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { +func (s *sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -259,7 +259,7 @@ func (s sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { } // Write implements jsonrpc2.Writer. -func (s sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { +func (s *sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { if ctx.Err() != nil { return ctx.Err() } @@ -288,7 +288,7 @@ func (s sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { // It must be safe to call Close more than once, as the close may // asynchronously be initiated by either the server closing its connection, or // by the hanging GET exiting. -func (s sseServerConn) Close() error { +func (s *sseServerConn) Close() error { s.t.mu.Lock() defer s.t.mu.Unlock() if !s.t.closed { diff --git a/mcp/transport.go b/mcp/transport.go index f2d5c72d..ac778db6 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -38,9 +38,25 @@ type Transport interface { // A Connection is a logical bidirectional JSON-RPC connection. type Connection interface { + // Read reads the next message to process off the connection. + // + // Read need not be safe for concurrent use: Read is called in a + // concurrency-safe manner by the JSON-RPC library. Read(context.Context) (jsonrpc.Message, error) + + // Write writes a new message to the connection. + // + // Write may be called concurrently, as calls or reponses may occur + // concurrently in user code. Write(context.Context, jsonrpc.Message) error - Close() error // may be called concurrently by both peers + + // Close closes the connection. It is implicitly called whenever a Read or + // Write fails. + // + // Close may be called multiple times, potentially concurrently. + Close() error + + // TODO(#148): remove SessionID from this interface. SessionID() string } @@ -264,7 +280,8 @@ func (r rwc) Close() error { // // See [msgBatch] for more discussion of message batching. type ioConn struct { - rwc io.ReadWriteCloser // the underlying stream + writeMu sync.Mutex // guards Write, which must be concurrency safe. + rwc io.ReadWriteCloser // the underlying stream // incoming receives messages from the read loop started in [newIOConn]. incoming <-chan msgOrErr @@ -497,6 +514,9 @@ func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { default: } + t.writeMu.Lock() + defer t.writeMu.Unlock() + // Batching support: if msg is a Response, it may have completed a batch, so // check that first. Otherwise, it is a request or notification, and we may // want to collect it into a batch before sending, if we're configured to use From 6ea7a6c65389276f257cc0371f87bb3748f146f5 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 8 Aug 2025 15:13:31 +0000 Subject: [PATCH 082/221] mcp: simplify and fix the streamable client transport Make several cleanups of the streamable client transport, encountered during work on JSON support for the streamable server: - The 'Close' condition is differentiated from asynchronous failures. A failure should unblock Read with an error, at which point the JSON-RPC connection will be broken and closed. - Fields are reordered in streamableClientConn to make guards more apparent. - The handling of sessionID is simplified: we simply set the session ID whenever we receive response headers. No need to have special handling for the first request, as the serializeation of session initialization is implemented in Client.Connect. - Since the above bullet makes Write a trivial wrapper around postMessage, the two methods are merged. - A bug is fixed where JSON responses were handled synchronously in Write. This lead to deadlock when a hanging client->server request is waiting on a server->client request. Now JSON is handled symmetrically to SSE: the Write returns once response headers are received. asynchronous to Write. - The httpConnection interface is renamed to clientConnection, and receive the entire InitializeResult. - streamableClientConn receivers are renamed to be consistently 'c'. --- mcp/client.go | 4 +- mcp/streamable.go | 294 ++++++++++++++++++++++------------------- mcp/streamable_test.go | 2 +- mcp/transport.go | 7 +- 4 files changed, 167 insertions(+), 140 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 7c68870c..add2b7da 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -134,8 +134,8 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e return nil, unsupportedProtocolVersionError{res.ProtocolVersion} } cs.initializeResult = res - if hc, ok := cs.mcpConn.(httpConnection); ok { - hc.setProtocolVersion(res.ProtocolVersion) + if hc, ok := cs.mcpConn.(clientConnection); ok { + hc.initialized(res) } if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() diff --git a/mcp/streamable.go b/mcp/streamable.go index 86af05ce..d50b0af6 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -738,164 +738,201 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er ReconnectOptions: reconnOpts, ctx: connCtx, cancel: cancel, + failed: make(chan struct{}), } - // Start the persistent SSE listener right away. - // Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint. - // This can be used to open an SSE stream, allowing the server to - // communicate to the client, without the client first sending data via HTTP POST. - go conn.handleSSE(nil, true) - return conn, nil } type streamableClientConn struct { url string + ReconnectOptions *StreamableReconnectOptions client *http.Client + ctx context.Context + cancel context.CancelFunc incoming chan []byte - done chan struct{} - ReconnectOptions *StreamableReconnectOptions + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error - ctx context.Context - cancel context.CancelFunc + done chan struct{} // signal graceful termination - mu sync.Mutex - protocolVersion string - _sessionID string - err error + // Logical reads are distributed across multiple http requests. Whenever any + // of them fails to process their response, we must break the connection, by + // failing the pending Read. + // + // Achieve this by storing the failure message, and signalling when reads are + // broken. See also [streamableClientConn.fail] and + // [streamableClientConn.failure]. + failOnce sync.Once + _failure error + failed chan struct{} // signal failure + + // Guard the initialization state. + mu sync.Mutex + initializedResult *InitializeResult + sessionID string } -func (c *streamableClientConn) setProtocolVersion(s string) { +func (c *streamableClientConn) initialized(res *InitializeResult) { c.mu.Lock() - defer c.mu.Unlock() - c.protocolVersion = s + c.initializedResult = res + c.mu.Unlock() + + // Start the persistent SSE listener as soon as we have the initialized + // result. + // + // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be + // used to open an SSE stream, allowing the server to communicate to the + // client, without the client first sending data via HTTP POST. + // + // We have to wait for initialized, because until we've received + // initialized, we don't know whether the server requires a sessionID. + // + // § 2.5: A server using the Streamable HTTP transport MAY assign a session + // ID at initialization time, by including it in an Mcp-Session-Id header + // on the HTTP response containing the InitializeResult. + go c.handleSSE(nil, true) +} + +// fail handles an asynchronous error while reading. +// +// If err is non-nil, it is terminal, and subsequent (or pending) Reads will +// fail. +func (c *streamableClientConn) fail(err error) { + if err != nil { + c.failOnce.Do(func() { + c._failure = err + close(c.failed) + }) + } +} + +func (c *streamableClientConn) failure() error { + select { + case <-c.failed: + return c._failure + default: + return nil + } } func (c *streamableClientConn) SessionID() string { c.mu.Lock() defer c.mu.Unlock() - return c._sessionID + return c.sessionID } // Read implements the [Connection] interface. -func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { - s.mu.Lock() - err := s.err - s.mu.Unlock() - if err != nil { +func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + if err := c.failure(); err != nil { return nil, err } select { case <-ctx.Done(): return nil, ctx.Err() - case <-s.done: + case <-c.failed: + return nil, c.failure() + case <-c.done: return nil, io.EOF - case data := <-s.incoming: + case data := <-c.incoming: return jsonrpc2.DecodeMessage(data) } } // Write implements the [Connection] interface. -func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { - s.mu.Lock() - if s.err != nil { - s.mu.Unlock() - return s.err - } - - sessionID := s._sessionID - if sessionID == "" { - // Hold lock for the first request. - defer s.mu.Unlock() - } else { - s.mu.Unlock() - } - - gotSessionID, err := s.postMessage(ctx, sessionID, msg) - if err != nil { - if sessionID != "" { - // unlocked; lock to set err - s.mu.Lock() - defer s.mu.Unlock() - } - if s.err != nil { - s.err = err - } +func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if err := c.failure(); err != nil { return err } - if sessionID == "" { - // locked - s._sessionID = gotSessionID - } - - return nil -} - -// postMessage POSTs msg to the server and reads the response. -// It returns the session ID from the response. -func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { - return "", err + return err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url, bytes.NewReader(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) if err != nil { - return "", err - } - if s.protocolVersion != "" { - req.Header.Set(protocolVersionHeader, s.protocolVersion) - } - if sessionID != "" { - req.Header.Set(sessionIDHeader, sessionID) + return err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + c.setMCPHeaders(req) - resp, err := s.client.Do(req) + resp, err := c.client.Do(req) if err != nil { - return "", err + return err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { // TODO: do a best effort read of the body here, and format it in the error. resp.Body.Close() - return "", fmt.Errorf("broken session: %v", resp.Status) + return fmt.Errorf("broken session: %v", resp.Status) } - sessionID = resp.Header.Get(sessionIDHeader) - switch ct := resp.Header.Get("Content-Type"); ct { - case "text/event-stream": - // Section 2.1: The SSE stream is initiated after a POST. - go s.handleSSE(resp, false) - case "application/json": - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return "", err + if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + c.mu.Lock() + hadSessionID := c.sessionID + if hadSessionID == "" { + c.sessionID = sessionID } - select { - case s.incoming <- body: - case <-s.done: - // The connection was closed by the client; exit gracefully. + c.mu.Unlock() + if hadSessionID != "" && hadSessionID != sessionID { + resp.Body.Close() + return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID) } - return sessionID, nil + } + if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted { + resp.Body.Close() + return nil + } + + switch ct := resp.Header.Get("Content-Type"); ct { + case "application/json": + go c.handleJSON(resp) + + case "text/event-stream": + go c.handleSSE(resp, false) + default: resp.Body.Close() - return "", fmt.Errorf("unsupported content type %q", ct) + return fmt.Errorf("unsupported content type %q", ct) + } + return nil +} + +func (c *streamableClientConn) setMCPHeaders(req *http.Request) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initializedResult != nil { + req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + } + if c.sessionID != "" { + req.Header.Set(sessionIDHeader, c.sessionID) + } +} + +func (c *streamableClientConn) handleJSON(resp *http.Response) { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + c.fail(err) + return + } + select { + case c.incoming <- body: + case <-c.done: + // The connection was closed by the client; exit gracefully. } - return sessionID, nil } // handleSSE manages the lifecycle of an SSE connection. It can be either // persistent (for the main GET listener) or temporary (for a POST response). -func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { +func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { resp := initialResp var lastEventID string for { - eventID, clientClosed := s.processStream(resp) + eventID, clientClosed := c.processStream(resp) lastEventID = eventID // If the connection was closed by the client, we're done. @@ -909,14 +946,10 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := s.reconnect(lastEventID) + newResp, err := c.reconnect(lastEventID) if err != nil { - // All reconnection attempts failed. Set the final error, close the - // connection, and exit the goroutine. - s.mu.Lock() - s.err = err - s.mu.Unlock() - s.Close() + // All reconnection attempts failed: fail the connection. + c.fail(err) return } @@ -926,11 +959,12 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent } // processStream reads from a single response body, sending events to the -// incoming channel. It returns the ID of the last processed event, any error -// that occurred, and a flag indicating if the connection was closed by the client. -// If resp is nil, it returns "", false. -func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { +// incoming channel. It returns the ID of the last processed event and a flag +// indicating if the connection was closed by the client. If resp is nil, it +// returns "", false. +func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { if resp == nil { + // TODO(rfindley): avoid this special handling. return "", false } @@ -945,8 +979,8 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s } select { - case s.incoming <- evt.Data: - case <-s.done: + case c.incoming <- evt.Data: + case <-c.done: // The connection was closed by the client; exit gracefully. return "", true } @@ -958,15 +992,15 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s // reconnect handles the logic of retrying a connection with an exponential // backoff strategy. It returns a new, valid HTTP response if successful, or // an error if all retries are exhausted. -func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { +func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { var finalErr error - for attempt := 0; attempt < s.ReconnectOptions.MaxRetries; attempt++ { + for attempt := 0; attempt < c.ReconnectOptions.MaxRetries; attempt++ { select { - case <-s.done: + case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") - case <-time.After(calculateReconnectDelay(s.ReconnectOptions, attempt)): - resp, err := s.establishSSE(lastEventID) + case <-time.After(calculateReconnectDelay(c.ReconnectOptions, attempt)): + resp, err := c.establishSSE(lastEventID) if err != nil { finalErr = err // Store the error and try again. continue @@ -983,9 +1017,9 @@ func (s *streamableClientConn) reconnect(lastEventID string) (*http.Response, er } // If the loop completes, all retries have failed. if finalErr != nil { - return nil, fmt.Errorf("connection failed after %d attempts: %w", s.ReconnectOptions.MaxRetries, finalErr) + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.ReconnectOptions.MaxRetries, finalErr) } - return nil, fmt.Errorf("connection failed after %d attempts", s.ReconnectOptions.MaxRetries) + return nil, fmt.Errorf("connection failed after %d attempts", c.ReconnectOptions.MaxRetries) } // isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. @@ -999,48 +1033,40 @@ func isResumable(resp *http.Response) bool { } // Close implements the [Connection] interface. -func (s *streamableClientConn) Close() error { - s.closeOnce.Do(func() { +func (c *streamableClientConn) Close() error { + c.closeOnce.Do(func() { // Cancel any hanging network requests. - s.cancel() - close(s.done) + c.cancel() + close(c.done) - req, err := http.NewRequest(http.MethodDelete, s.url, nil) + req, err := http.NewRequest(http.MethodDelete, c.url, nil) if err != nil { - s.closeErr = err + c.closeErr = err } else { - // TODO(jba): confirm that we don't need a lock here, or add locking. - if s.protocolVersion != "" { - req.Header.Set(protocolVersionHeader, s.protocolVersion) - } - req.Header.Set(sessionIDHeader, s._sessionID) - if _, err := s.client.Do(req); err != nil { - s.closeErr = err + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err } } }) - return s.closeErr + return c.closeErr } // establishSSE establishes the persistent SSE listening stream. // It is used for reconnect attempts using the Last-Event-ID header to // resume a broken stream where it left off. -func (s *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { - req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil) +func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) if err != nil { return nil, err } - s.mu.Lock() - if s._sessionID != "" { - req.Header.Set("Mcp-Session-Id", s._sessionID) - } - s.mu.Unlock() + c.setMCPHeaders(req) if lastEventID != "" { req.Header.Set("Last-Event-ID", lastEventID) } req.Header.Set("Accept", "text/event-stream") - return s.client.Do(req) + return c.client.Do(req) } // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 849f6026..b18e2fb1 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -79,7 +79,7 @@ func TestStreamableTransports(t *testing.T) { if sid == "" { t.Error("empty session ID") } - if g, w := session.mcpConn.(*streamableClientConn).protocolVersion, latestProtocolVersion; g != w { + if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { t.Fatalf("got protocol version %q, want %q", g, w) } // 4. The client calls the "greet" tool. diff --git a/mcp/transport.go b/mcp/transport.go index ac778db6..02b21806 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -60,10 +60,11 @@ type Connection interface { SessionID() string } -// An httpConnection is a [Connection] that runs over HTTP. -type httpConnection interface { +// A clientConnection is a [Connection] that is specific to the MCP client, and +// so may receive information about the client session. +type clientConnection interface { Connection - setProtocolVersion(string) + initialized(*InitializeResult) } // A StdioTransport is a [Transport] that communicates over stdin/stdout using From 0375535db46aa5b2117f4065a20a5d594ce8860d Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 8 Aug 2025 14:44:03 +0000 Subject: [PATCH 083/221] mcp: implement support for JSON responses in the MCP streamable server Add a new (currently unexported) jsonResponse option to StreamableServerTransportOptions, and use it to control the response content type, serving application/json if set. Additionally: - A transportOptions field is added to the StreamableHTTPOptions. - A messages iterator is added to encapsulate the iteration of stream messages, since the handling of JSON and SSE responses are otherwise very different. - The serving flow is refactored to avoid returning (statusCode, message), primarily because this seemed liable to lead to redundant calls to WriteHeader, because only local logic knows whether or not any data has been written to the response. - The serving flow is refactored to delegate to responseJSON and responseSSE, according to the currently unexported jsonResponse option. - A bug is fixed where all GET streams were considered persistent: the terminal condition req.Method == http.MethodPost && nOutstanding == 0 was not right: GET requests may implement stream resumption. Updates #211 --- mcp/client.go | 2 +- mcp/shared.go | 2 + mcp/streamable.go | 338 ++++++++++++++++++++++++++++------------- mcp/streamable_test.go | 152 +++++++++--------- 4 files changed, 315 insertions(+), 179 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index add2b7da..d3139f54 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -250,7 +250,7 @@ func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsPara func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *CreateMessageParams) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? - return nil, &jsonrpc2.WireError{Code: CodeUnsupportedMethod, Message: "client does not support CreateMessage"} + return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage") } return c.opts.CreateMessageHandler(ctx, cs, params) } diff --git a/mcp/shared.go b/mcp/shared.go index 8d0ceb1c..460fb6c3 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -275,6 +275,8 @@ func sessionMethod[S Session, P Params, R Result](f func(S, context.Context, P) // Error codes const ( + // TODO: should these be unexported? + CodeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. CodeUnsupportedMethod = -31001 diff --git a/mcp/streamable.go b/mcp/streamable.go index d50b0af6..108de5d2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -7,9 +7,11 @@ package mcp import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" + "iter" "math" "math/rand/v2" "net/http" @@ -34,6 +36,7 @@ const ( // [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html type StreamableHTTPHandler struct { getServer func(*http.Request) *Server + opts StreamableHTTPOptions sessionsMu sync.Mutex sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) @@ -44,6 +47,10 @@ type StreamableHTTPHandler struct { type StreamableHTTPOptions struct { // TODO: support configurable session ID generation (?) // TODO: support session retention (?) + + // transportOptions sets the streamable server transport options to use when + // establishing a new session. + transportOptions *StreamableServerTransportOptions } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -52,10 +59,14 @@ type StreamableHTTPOptions struct { // sessions. It is OK for getServer to return the same server multiple times. // If getServer returns nil, a 400 Bad Request will be served. func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { - return &StreamableHTTPHandler{ + h := &StreamableHTTPHandler{ getServer: getServer, sessions: make(map[string]*StreamableServerTransport), } + if opts != nil { + h.opts = *opts + } + return h } // closeAll closes all ongoing sessions. @@ -134,7 +145,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if session == nil { - s := NewStreamableServerTransport(randText(), nil) + s := NewStreamableServerTransport(randText(), h.opts.transportOptions) server := h.getServer(req) if server == nil { // The getServer argument to NewStreamableHTTPHandler returned nil. @@ -161,6 +172,15 @@ type StreamableServerTransportOptions struct { // Storage for events, to enable stream resumption. // If nil, a [MemoryEventStore] with the default maximum size will be used. EventStore EventStore + + // jsonResponse, if set, tells the server to prefer to respond to requests + // using application/json responses rather than text/event-stream. + // + // Specifically, responses will be application/json whenever incoming POST + // request contain only a single message. In this case, notifications or + // requests made within the context of a server request will be sent to the + // hanging GET request, if any. + jsonResponse bool } // NewStreamableServerTransport returns a new [StreamableServerTransport] with @@ -182,7 +202,11 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp streams: make(map[StreamID]*stream), requestStreams: make(map[jsonrpc.ID]StreamID), } - t.streams[0] = newStream(0) + // Stream 0 corresponds to the hanging 'GET'. + // + // It is always text/event-stream, since it must carry arbitrarily many + // messages. + t.streams[0] = newStream(0, false) if opts != nil { t.opts = *opts } @@ -199,7 +223,7 @@ func (t *StreamableServerTransport) SessionID() string { // A StreamableServerTransport implements the [Transport] interface for a // single session. type StreamableServerTransport struct { - nextStreamID atomic.Int64 // incrementing next stream ID + lastStreamID atomic.Int64 // last stream ID used, atomically incremented sessionID string opts StreamableServerTransportOptions @@ -210,11 +234,12 @@ type StreamableServerTransport struct { // Sessions are closed exactly once. isDone bool - // Sessions can have multiple logical connections, corresponding to HTTP - // requests. Additionally, logical sessions may be resumed by subsequent HTTP - // requests, when the session is terminated unexpectedly. + // Sessions can have multiple logical connections (which we call streams), + // corresponding to HTTP requests. Additionally, streams may be resumed by + // subsequent HTTP requests, when the HTTP connection is terminated + // unexpectedly. // - // Therefore, we use a logical connection ID to key the connection state, and + // Therefore, we use a logical stream ID to key the stream state, and // perform the accounting described below when incoming HTTP requests are // handled. @@ -227,10 +252,9 @@ type StreamableServerTransport struct { // requestStreams maps incoming requests to their logical stream ID. // - // Lifecycle: requestStreams persists for the duration of the session. + // Lifecycle: requestStreams persist for the duration of the session. // - // TODO(rfindley): clean up once requests are handled. See the TODO for streams - // above. + // TODO: clean up once requests are handled. See the TODO for streams above. requestStreams map[jsonrpc.ID]StreamID } @@ -245,6 +269,12 @@ type stream struct { // ID 0 is used for messages that don't correlate with an incoming request. id StreamID + // jsonResponse records whether this stream should respond with application/json + // instead of text/event-stream. + // + // See [StreamableServerTransportOptions.jsonResponse]. + jsonResponse bool + // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals // that there are messages available to write into the HTTP response. // In addition, the presence of a channel guarantees that at most one HTTP response @@ -271,10 +301,11 @@ type stream struct { requests map[jsonrpc.ID]struct{} } -func newStream(id StreamID) *stream { +func newStream(id StreamID, jsonResponse bool) *stream { return &stream{ - id: id, - requests: make(map[jsonrpc.ID]struct{}), + id: id, + jsonResponse: jsonResponse, + requests: make(map[jsonrpc.ID]struct{}), } } @@ -319,25 +350,23 @@ type idContextKey struct{} // ServeHTTP handles a single HTTP request for the session. func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { - status := 0 - message := "" switch req.Method { case http.MethodGet: - status, message = t.serveGET(w, req) + t.serveGET(w, req) case http.MethodPost: - status, message = t.servePOST(w, req) + t.servePOST(w, req) default: // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. w.Header().Set("Allow", "GET, POST") - status = http.StatusMethodNotAllowed - message = "unsupported method" - } - if status != 0 && status != http.StatusOK { - http.Error(w, message, status) + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) } } -func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) { +// serveGET streams messages to a hanging http GET, with stream ID and last +// message parsed from the Last-Event-ID header. +// +// It returns an HTTP status code and error message. +func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. id := StreamID(0) // By default, we haven't seen a last index. Since indices start at 0, we represent @@ -349,7 +378,8 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re var ok bool id, lastIdx, ok = parseEventID(eid) if !ok { - return http.StatusBadRequest, fmt.Sprintf("malformed Last-Event-ID %q", eid) + http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) + return } } @@ -357,31 +387,50 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re stream, ok := t.streams[id] t.mu.Unlock() if !ok { - return http.StatusBadRequest, "unknown stream" + http.Error(w, "unknown stream", http.StatusBadRequest) + return } if !stream.signal.CompareAndSwap(nil, signalChanPtr()) { // The CAS returned false, meaning that the comparison failed: stream.signal is not nil. - return http.StatusBadRequest, "stream ID conflicts with ongoing stream" + http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) + return } - return t.streamResponse(stream, w, req, lastIdx) + defer stream.signal.Store(nil) + persistent := id == 0 // Only the special stream 0 is a hanging get. + t.respondSSE(stream, w, req, lastIdx, persistent) } -func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) (int, string) { +// servePOST handles an incoming message, and replies with either an outgoing +// message stream or single response object, depending on whether the +// jsonResponse option is set. +// +// It returns an HTTP status code and error message. +func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { if len(req.Header.Values("Last-Event-ID")) > 0 { - return http.StatusBadRequest, "can't send Last-Event-ID for POST request" + http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) + return } // Read incoming messages. body, err := io.ReadAll(req.Body) if err != nil { - return http.StatusBadRequest, "failed to read body" + http.Error(w, "failed to read body", http.StatusBadRequest) + return } if len(body) == 0 { - return http.StatusBadRequest, "POST requires a non-empty body" + http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) + return } + // TODO(#21): if the negotiated protocol version is 2025-06-18 or later, + // we should not allow batching here. + // + // This also requires access to the negotiated version, which would either be + // set by the MCP-Protocol-Version header, or would require peeking into the + // session. incoming, _, err := readBatch(body) if err != nil { - return http.StatusBadRequest, fmt.Sprintf("malformed payload: %v", err) + http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) + return } requests := make(map[jsonrpc.ID]struct{}) for _, msg := range incoming { @@ -390,7 +439,8 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. if _, err := checkRequest(req, serverMethodInfos); err != nil { - return http.StatusBadRequest, err.Error() + http.Error(w, err.Error(), http.StatusBadRequest) + return } if req.ID.IsValid() { requests[req.ID] = struct{}{} @@ -398,39 +448,84 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R } } - // Update accounting for this request. - stream := newStream(StreamID(t.nextStreamID.Add(1))) - t.mu.Lock() - t.streams[stream.id] = stream + var stream *stream // if non-nil, used to handle requests + + // If we have requests, we need to handle responses along with any + // notifications or server->client requests made in the course of handling. + // Update accounting for this incoming payload. if len(requests) > 0 { - stream.requests = make(map[jsonrpc.ID]struct{}) - } - for reqID := range requests { - t.requestStreams[reqID] = stream.id - stream.requests[reqID] = struct{}{} + stream = newStream(StreamID(t.lastStreamID.Add(1)), t.opts.jsonResponse) + t.mu.Lock() + t.streams[stream.id] = stream + stream.requests = requests + for reqID := range requests { + t.requestStreams[reqID] = stream.id + } + t.mu.Unlock() + stream.signal.Store(signalChanPtr()) } - t.mu.Unlock() - stream.signal.Store(signalChanPtr()) // Publish incoming messages. for _, msg := range incoming { t.incoming <- msg } - // TODO(rfindley): consider optimizing for a single incoming request, by - // responding with application/json when there is only a single message in - // the response. - // (But how would we know there is only a single message? For example, couldn't - // a progress notification be sent before a response on the same context?) - return t.streamResponse(stream, w, req, -1) + if stream == nil { + w.WriteHeader(http.StatusAccepted) + return + } + + if stream.jsonResponse { + t.respondJSON(stream, w, req) + } else { + t.respondSSE(stream, w, req, -1, false) + } } -// lastIndex is the index of the last seen event if resuming, else -1. -func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) { - defer stream.signal.Store(nil) +func (t *StreamableServerTransport) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "application/json") + w.Header().Set(sessionIDHeader, t.sessionID) + + var msgs []json.RawMessage + ctx := req.Context() + for msg, ok := range t.messages(ctx, stream, false) { + if !ok { + if ctx.Err() != nil { + w.WriteHeader(http.StatusNoContent) + return + } else { + http.Error(w, http.StatusText(http.StatusGone), http.StatusGone) + return + } + } + msgs = append(msgs, msg) + } + var data []byte + if len(msgs) == 1 { + data = []byte(msgs[0]) + } else { + // TODO: add tests for batch responses, or disallow them entirely. + var err error + data, err = json.Marshal(msgs) + if err != nil { + http.Error(w, fmt.Sprintf("internal error marshalling response: %v", err), http.StatusInternalServerError) + return + } + } + _, _ = w.Write(data) // ignore error: client disconnected +} +// lastIndex is the index of the last seen event if resuming, else -1. +func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { writes := 0 + // Accept checked in [StreamableHTTPHandler] + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] + w.Header().Set("Connection", "keep-alive") + w.Header().Set(sessionIDHeader, t.sessionID) + // write one event containing data. write := func(data []byte) bool { lastIndex++ @@ -448,10 +543,13 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon return true } - w.Header().Set(sessionIDHeader, t.sessionID) - w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Connection", "keep-alive") + errorf := func(code int, format string, args ...any) { + if writes == 0 { + http.Error(w, fmt.Sprintf(format, args...), code) + } else { + // TODO(#170): log when we add server-side logging + } + } if lastIndex >= 0 { // Resume. @@ -464,65 +562,83 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon if errors.Is(err, ErrEventsPurged) { status = http.StatusInsufficientStorage } - return status, err.Error() + errorf(status, "failed to read events: %v", err) + return } // The iterator yields events beginning just after lastIndex, or it would have // yielded an error. if !write(data) { - return 0, "" + return } } } -stream: // Repeatedly collect pending outgoing events and send them. - for { - t.mu.Lock() - outgoing := stream.outgoing - stream.outgoing = nil - nOutstanding := len(stream.requests) - t.mu.Unlock() - - for _, data := range outgoing { - if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil { - return http.StatusInternalServerError, err.Error() - } - if !write(data) { - return 0, "" + ctx := req.Context() + for msg, ok := range t.messages(ctx, stream, persistent) { + if !ok { + if ctx.Err() != nil && writes == 0 { + // This probably doesn't matter, but respond with NoContent if the client disconnected. + w.WriteHeader(http.StatusNoContent) + } else { + errorf(http.StatusGone, "stream terminated") } + return } + if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, msg); err != nil { + errorf(http.StatusInternalServerError, "storing event: %v", err.Error()) + return + } + if !write(msg) { + return + } + } +} - // If all requests have been handled and replied to, we should terminate this connection. - // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." - // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET - // (stream ID 0) will never have requests, and should remain open indefinitely. - // TODO: implement the GET case. - if req.Method == http.MethodPost && nOutstanding == 0 { - if writes == 0 { - // Spec: If the server accepts the input, the server MUST return HTTP - // status code 202 Accepted with no body. - w.WriteHeader(http.StatusAccepted) +// messages iterates over messages sent to the current stream. +// +// The first iterated value is the received JSON message. The second iterated +// value is an OK value indicating whether the stream terminated normally. +// +// If the stream did not terminate normally, it is either because ctx was +// cancelled, or the connection is closed: check the ctx.Err() to differentiate +// these cases. +func (t *StreamableServerTransport) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] { + return func(yield func(json.RawMessage, bool) bool) { + for { + t.mu.Lock() + outgoing := stream.outgoing + stream.outgoing = nil + nOutstanding := len(stream.requests) + t.mu.Unlock() + + for _, data := range outgoing { + if !yield(data, true) { + return + } } - return 0, "" - } - select { - case <-*stream.signal.Load(): // there are new outgoing messages - // return to top of loop - case <-t.done: // session is closed - if writes == 0 { - return http.StatusGone, "session terminated" + // If all requests have been handled and replied to, we should terminate this connection. + // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." + // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET + // (stream ID 0) will never have requests, and should remain open indefinitely. + if nOutstanding == 0 && !persistent { + return } - break stream - case <-req.Context().Done(): - if writes == 0 { - w.WriteHeader(http.StatusNoContent) + + select { + case <-*stream.signal.Load(): // there are new outgoing messages + // return to top of loop + case <-t.done: // session is closed + yield(nil, false) + return + case <-ctx.Done(): + yield(nil, false) + return } - break stream } } - return 0, "" } // Event IDs: encode both the logical connection ID and the index, as @@ -582,7 +698,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa isResponse = true } else { // Otherwise, we check to see if it request was made in the context of an - // ongoing request. This may not be the case if the request way made with + // ongoing request. This may not be the case if the request was made with // an unrelated context. if v := ctx.Value(idContextKey{}); v != nil { forRequest = v.(jsonrpc.ID) @@ -593,10 +709,10 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa // // For messages sent outside of a request context, this is the default // connection 0. - var forConn StreamID + var forStream StreamID if forRequest.IsValid() { t.mu.Lock() - forConn = t.requestStreams[forRequest] + forStream = t.requestStreams[forRequest] t.mu.Unlock() } @@ -611,15 +727,19 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa return errors.New("session is closed") } - stream := t.streams[forConn] + stream := t.streams[forStream] if stream == nil { - return fmt.Errorf("no stream with ID %d", forConn) + return fmt.Errorf("no stream with ID %d", forStream) } - if len(stream.requests) == 0 && forConn != 0 { - // No outstanding requests for this connection, which means it is logically - // done. This is a sequencing violation from the server, so we should report - // a side-channel error here. Put the message on the general queue to avoid - // dropping messages. + + // Special case a few conditions where we fall back on stream 0 (the hanging GET): + // + // - if forStream is known, but the associated stream is logically complete + // - if the stream is application/json, but the message is not a response + // + // TODO(rfindley): either of these, particularly the first, might be + // considered a bug in the server. Report it through a side-channel? + if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse { stream = t.streams[0] } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index b18e2fb1..24368b00 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -35,77 +35,90 @@ func TestStreamableTransports(t *testing.T) { ctx := context.Background() - // 1. Create a server with a simple "greet" tool. - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) - // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a - // cookie-checking middleware. - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - var header http.Header - httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - header = r.Header - cookie, err := r.Cookie("test-cookie") - if err != nil { - t.Errorf("missing cookie: %v", err) - } else if cookie.Value != "test-value" { - t.Errorf("got cookie %q, want %q", cookie.Value, "test-value") - } - handler.ServeHTTP(w, r) - })) - defer httpServer.Close() + for _, useJSON := range []bool{false, true} { + t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) { + // 1. Create a server with a simple "greet" tool. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + + // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a + // cookie-checking middleware. + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + transportOptions: &StreamableServerTransportOptions{jsonResponse: useJSON}, + }) - // 3. Create a client and connect it to the server using our StreamableClientTransport. - // Check that all requests honor a custom client. - jar, err := cookiejar.New(nil) - if err != nil { - t.Fatal(err) - } - u, err := url.Parse(httpServer.URL) - if err != nil { - t.Fatal(err) - } - jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}}) - httpClient := &http.Client{Jar: jar} - transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ - HTTPClient: httpClient, - }) - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - defer session.Close() - sid := session.ID() - if sid == "" { - t.Error("empty session ID") - } - if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { - t.Fatalf("got protocol version %q, want %q", g, w) - } - // 4. The client calls the "greet" tool. - params := &CallToolParams{ - Name: "greet", - Arguments: map[string]any{"name": "streamy"}, - } - got, err := session.CallTool(ctx, params) - if err != nil { - t.Fatalf("CallTool() failed: %v", err) - } - if g := session.ID(); g != sid { - t.Errorf("session ID: got %q, want %q", g, sid) - } - if g, w := header.Get(protocolVersionHeader), latestProtocolVersion; g != w { - t.Errorf("got protocol version header %q, want %q", g, w) - } + var ( + headerMu sync.Mutex + lastHeader http.Header + ) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerMu.Lock() + lastHeader = r.Header + headerMu.Unlock() + cookie, err := r.Cookie("test-cookie") + if err != nil { + t.Errorf("missing cookie: %v", err) + } else if cookie.Value != "test-value" { + t.Errorf("got cookie %q, want %q", cookie.Value, "test-value") + } + handler.ServeHTTP(w, r) + })) + defer httpServer.Close() - // 5. Verify that the correct response is received. - want := &CallToolResult{ - Content: []Content{ - &TextContent{Text: "hi streamy"}, - }, - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) + // 3. Create a client and connect it to the server using our StreamableClientTransport. + // Check that all requests honor a custom client. + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } + u, err := url.Parse(httpServer.URL) + if err != nil { + t.Fatal(err) + } + jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}}) + httpClient := &http.Client{Jar: jar} + transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ + HTTPClient: httpClient, + }) + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + sid := session.ID() + if sid == "" { + t.Error("empty session ID") + } + if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { + t.Fatalf("got protocol version %q, want %q", g, w) + } + // 4. The client calls the "greet" tool. + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "streamy"}, + } + got, err := session.CallTool(ctx, params) + if err != nil { + t.Fatalf("CallTool() failed: %v", err) + } + if g := session.ID(); g != sid { + t.Errorf("session ID: got %q, want %q", g, sid) + } + if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w { + t.Errorf("got protocol version header %q, want %q", g, w) + } + + // 5. Verify that the correct response is received. + want := &CallToolResult{ + Content: []Content{ + &TextContent{Text: "hi streamy"}, + }, + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) + } + }) } } @@ -745,6 +758,7 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { t.Fatal(err) } w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "123") w.Write(data) } From 388e000aa2ff01a2147c7ca17f2bace45803eb9c Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Sat, 9 Aug 2025 11:04:44 -0400 Subject: [PATCH 084/221] .github: update go 1.25 builds to rc3 (#271) Now that rc3 is out, update our 1.25 builds to use it. --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d5009944..54b36331 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,7 +31,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.23', '1.24', '1.25.0-rc.2' ] + go: [ '1.23', '1.24', '1.25.0-rc.3' ] steps: - name: Check out code uses: actions/checkout@v4 From 4e413da7710d769262fe5a35f8d09772abc6a09e Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Sat, 9 Aug 2025 11:06:59 -0400 Subject: [PATCH 085/221] mcp: rename mcp{Params,Result} to is{Params,Result} (#270) For some reason, when making these marker methods, the standard naming convention escaped me. --- mcp/protocol.go | 70 ++++++++++++++++++++++++------------------------- mcp/shared.go | 10 +++---- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/mcp/protocol.go b/mcp/protocol.go index a2c0843e..d2d343b8 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -78,7 +78,7 @@ type CallToolResultFor[Out any] struct { IsError bool `json:"isError,omitempty"` } -func (*CallToolResultFor[Out]) mcpResult() {} +func (*CallToolResultFor[Out]) isResult() {} // UnmarshalJSON handles the unmarshalling of content into the Content // interface. @@ -99,7 +99,7 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { return nil } -func (x *CallToolParamsFor[Out]) mcpParams() {} +func (x *CallToolParamsFor[Out]) isParams() {} func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) } @@ -117,7 +117,7 @@ type CancelledParams struct { RequestID any `json:"requestId"` } -func (x *CancelledParams) mcpParams() {} +func (x *CancelledParams) isParams() {} func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -211,7 +211,7 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } -func (*CompleteParams) mcpParams() {} +func (*CompleteParams) isParams() {} type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` @@ -227,7 +227,7 @@ type CompleteResult struct { Completion CompletionResultDetails `json:"completion"` } -func (*CompleteResult) mcpResult() {} +func (*CompleteResult) isResult() {} type CreateMessageParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -253,7 +253,7 @@ type CreateMessageParams struct { Temperature float64 `json:"temperature,omitempty"` } -func (x *CreateMessageParams) mcpParams() {} +func (x *CreateMessageParams) isParams() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -273,7 +273,7 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } -func (*CreateMessageResult) mcpResult() {} +func (*CreateMessageResult) isResult() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -301,7 +301,7 @@ type GetPromptParams struct { Name string `json:"name"` } -func (x *GetPromptParams) mcpParams() {} +func (x *GetPromptParams) isParams() {} func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -315,7 +315,7 @@ type GetPromptResult struct { Messages []*PromptMessage `json:"messages"` } -func (*GetPromptResult) mcpResult() {} +func (*GetPromptResult) isResult() {} type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -328,7 +328,7 @@ type InitializeParams struct { ProtocolVersion string `json:"protocolVersion"` } -func (x *InitializeParams) mcpParams() {} +func (x *InitializeParams) isParams() {} func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -352,7 +352,7 @@ type InitializeResult struct { ServerInfo *Implementation `json:"serverInfo"` } -func (*InitializeResult) mcpResult() {} +func (*InitializeResult) isResult() {} type InitializedParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -360,7 +360,7 @@ type InitializedParams struct { Meta `json:"_meta,omitempty"` } -func (x *InitializedParams) mcpParams() {} +func (x *InitializedParams) isParams() {} func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -373,7 +373,7 @@ type ListPromptsParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListPromptsParams) mcpParams() {} +func (x *ListPromptsParams) isParams() {} func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -389,7 +389,7 @@ type ListPromptsResult struct { Prompts []*Prompt `json:"prompts"` } -func (x *ListPromptsResult) mcpResult() {} +func (x *ListPromptsResult) isResult() {} func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourceTemplatesParams struct { @@ -401,7 +401,7 @@ type ListResourceTemplatesParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListResourceTemplatesParams) mcpParams() {} +func (x *ListResourceTemplatesParams) isParams() {} func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -417,7 +417,7 @@ type ListResourceTemplatesResult struct { ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` } -func (x *ListResourceTemplatesResult) mcpResult() {} +func (x *ListResourceTemplatesResult) isResult() {} func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourcesParams struct { @@ -429,7 +429,7 @@ type ListResourcesParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListResourcesParams) mcpParams() {} +func (x *ListResourcesParams) isParams() {} func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -445,7 +445,7 @@ type ListResourcesResult struct { Resources []*Resource `json:"resources"` } -func (x *ListResourcesResult) mcpResult() {} +func (x *ListResourcesResult) isResult() {} func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } type ListRootsParams struct { @@ -454,7 +454,7 @@ type ListRootsParams struct { Meta `json:"_meta,omitempty"` } -func (x *ListRootsParams) mcpParams() {} +func (x *ListRootsParams) isParams() {} func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -468,7 +468,7 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } -func (*ListRootsResult) mcpResult() {} +func (*ListRootsResult) isResult() {} type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -479,7 +479,7 @@ type ListToolsParams struct { Cursor string `json:"cursor,omitempty"` } -func (x *ListToolsParams) mcpParams() {} +func (x *ListToolsParams) isParams() {} func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -495,7 +495,7 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } -func (x *ListToolsResult) mcpResult() {} +func (x *ListToolsResult) isResult() {} func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } // The severity of a log message. @@ -517,7 +517,7 @@ type LoggingMessageParams struct { Logger string `json:"logger,omitempty"` } -func (x *LoggingMessageParams) mcpParams() {} +func (x *LoggingMessageParams) isParams() {} func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -579,7 +579,7 @@ type PingParams struct { Meta `json:"_meta,omitempty"` } -func (x *PingParams) mcpParams() {} +func (x *PingParams) isParams() {} func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -599,7 +599,7 @@ type ProgressNotificationParams struct { Total float64 `json:"total,omitempty"` } -func (*ProgressNotificationParams) mcpParams() {} +func (*ProgressNotificationParams) isParams() {} // A prompt or prompt template that the server offers. type Prompt struct { @@ -638,7 +638,7 @@ type PromptListChangedParams struct { Meta `json:"_meta,omitempty"` } -func (x *PromptListChangedParams) mcpParams() {} +func (x *PromptListChangedParams) isParams() {} func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -679,7 +679,7 @@ type ReadResourceParams struct { URI string `json:"uri"` } -func (x *ReadResourceParams) mcpParams() {} +func (x *ReadResourceParams) isParams() {} func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -691,7 +691,7 @@ type ReadResourceResult struct { Contents []*ResourceContents `json:"contents"` } -func (*ReadResourceResult) mcpResult() {} +func (*ReadResourceResult) isResult() {} // A known resource that the server is capable of reading. type Resource struct { @@ -733,7 +733,7 @@ type ResourceListChangedParams struct { Meta `json:"_meta,omitempty"` } -func (x *ResourceListChangedParams) mcpParams() {} +func (x *ResourceListChangedParams) isParams() {} func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -791,7 +791,7 @@ type RootsListChangedParams struct { Meta `json:"_meta,omitempty"` } -func (x *RootsListChangedParams) mcpParams() {} +func (x *RootsListChangedParams) isParams() {} func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -836,7 +836,7 @@ type SetLevelParams struct { Level LoggingLevel `json:"level"` } -func (x *SetLevelParams) mcpParams() {} +func (x *SetLevelParams) isParams() {} func (x *SetLevelParams) GetProgressToken() any { return getProgressToken(x) } func (x *SetLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -912,7 +912,7 @@ type ToolListChangedParams struct { Meta `json:"_meta,omitempty"` } -func (x *ToolListChangedParams) mcpParams() {} +func (x *ToolListChangedParams) isParams() {} func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -926,7 +926,7 @@ type SubscribeParams struct { URI string `json:"uri"` } -func (*SubscribeParams) mcpParams() {} +func (*SubscribeParams) isParams() {} // Sent from the client to request cancellation of resources/updated // notifications from the server. This should follow a previous @@ -939,7 +939,7 @@ type UnsubscribeParams struct { URI string `json:"uri"` } -func (*UnsubscribeParams) mcpParams() {} +func (*UnsubscribeParams) isParams() {} // A notification from the server to the client, informing it that a resource // has changed and may need to be read again. This should only be sent if the @@ -952,7 +952,7 @@ type ResourceUpdatedNotificationParams struct { URI string `json:"uri"` } -func (*ResourceUpdatedNotificationParams) mcpParams() {} +func (*ResourceUpdatedNotificationParams) isParams() {} // TODO(jba): add CompleteRequest and related types. diff --git a/mcp/shared.go b/mcp/shared.go index 460fb6c3..e3688641 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -337,8 +337,8 @@ func setProgressToken(p Params, pt any) { // Params is a parameter (input) type for an MCP call or notification. type Params interface { - // mcpParams discourages implementation of Params outside of this package. - mcpParams() + // isParams discourages implementation of Params outside of this package. + isParams() // GetMeta returns metadata from a value. GetMeta() map[string]any @@ -361,8 +361,8 @@ type RequestParams interface { // Result is a result of an MCP call. type Result interface { - // mcpResult discourages implementation of Result outside of this package. - mcpResult() + // isResult discourages implementation of Result outside of this package. + isResult() // GetMeta returns metadata from a value. GetMeta() map[string]any @@ -374,7 +374,7 @@ type Result interface { // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} -func (*emptyResult) mcpResult() {} +func (*emptyResult) isResult() {} func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") } From cccc086518677026c16924cf8052d32fec188d1c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 9 Aug 2025 10:43:15 -0400 Subject: [PATCH 086/221] mcp: clarify Server.Run use Make it clear that it's not needed for an HTTP server. --- mcp/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mcp/server.go b/mcp/server.go index e69a872e..d81dec60 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -513,6 +513,10 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns // advertise the capability for tools, including the ability to send list-changed notifications. // If no tools have been added, the server will not have the tool capability. // The same goes for other features like prompts and resources. +// +// Run is a convenience for servers that handle a single session (or one session at a time). +// It need not be called on servers that are used for multiple concurrent connections, +// as with [StreamableHTTPHandler]. func (s *Server) Run(ctx context.Context, t Transport) error { ss, err := s.Connect(ctx, t) if err != nil { From be1ddf5f801cc6eb3dda401e5ed3ba6956017f89 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 9 Aug 2025 12:12:38 -0400 Subject: [PATCH 087/221] jsonrpc: paper over race in jsonrpc2 test In jsonprc2_test.go, binder.Bind starts a goroutine. That goroutine begins to run during jsonrpc2.bindConnection, and can race with the setting of conn.write in bindConnection. This PR adds a sleep, which is a poor way to deal with the race, but the least invasive change. Better ones include running the test function after Dial returns, or adding a Connection.Ready method to detect when initialization is complete. --- internal/jsonrpc2/jsonrpc2_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 35f4e7f9..16a5039b 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -11,6 +11,7 @@ import ( "path" "reflect" "testing" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) @@ -143,6 +144,8 @@ func testConnection(t *testing.T, framer jsonrpc2.Framer) { t.Run(test.Name(), func(t *testing.T) { client, err := jsonrpc2.Dial(ctx, listener.Dialer(), binder{framer, func(h *handler) { + // Sleep a little to a void a race with setting conn.writer in jsonrpc2.bindConnection. + time.Sleep(50 * time.Millisecond) defer h.conn.Close() test.Invoke(t, ctx, h) if call, ok := test.(*call); ok { From cb27392b2dcd8728b3cdecdaa19f19c537a87fde Mon Sep 17 00:00:00 2001 From: Kartik Verma Date: Mon, 11 Aug 2025 18:03:50 +0530 Subject: [PATCH 088/221] .gitignore: add Add `.gitignore` for some common files. --- .gitignore | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..17341605 --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ +# Builds +*.exe +*.exe~ +dist/ +build/ +bin/ +*.tmp + +# IDE files +.vscode/ +*.code-workspace +.idea/ +*~ + +# Go Specific +*.prof +*.pprof +*.out +*.coverage +coverage.txt +coverage.html + +# OS generated files +# macOS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Windows +Desktop.ini +$RECYCLE.BIN/ + +# Linux +.nfs* From 119a58314d8185bffebdd873406d1f4f788bd0fb Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Mon, 11 Aug 2025 12:45:14 -0400 Subject: [PATCH 089/221] mcp: address some comments from #234 (#248) I inadvertently merged #234 without pushing my latest patches. This CL adds the missing changes. --- mcp/cmd_test.go | 13 +++++++++++-- mcp/transport.go | 6 +++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 82a35a80..fd807cfd 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -129,7 +129,9 @@ func TestServerInterrupt(t *testing.T) { }() // send a signal to the server process to terminate it - cmd.Process.Signal(os.Interrupt) + if err := cmd.Process.Signal(os.Interrupt); err != nil { + t.Fatal(err) + } // wait for the server to exit // TODO: use synctest when available @@ -162,6 +164,11 @@ func TestStdioContextCancellation(t *testing.T) { } // Sleep to make it more likely that the server is blocked in the read loop. + // + // This sleep isn't necessary for the test to pass, but *was* necessary for + // it to fail, before closing was fixed. Unfortunately, it is too invasive a + // change to have the jsonrpc2 package signal across packages when it is + // actually blocked in its read loop. time.Sleep(100 * time.Millisecond) onExit := make(chan struct{}) @@ -170,7 +177,9 @@ func TestStdioContextCancellation(t *testing.T) { close(onExit) }() - cmd.Process.Signal(os.Interrupt) + if err := cmd.Process.Signal(os.Interrupt); err != nil { + t.Fatal(err) + } select { case <-time.After(5 * time.Second): diff --git a/mcp/transport.go b/mcp/transport.go index 02b21806..6ffb67f5 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -318,9 +318,9 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { // Start a goroutine for reads, so that we can select on the incoming channel // in [ioConn.Read] and unblock the read as soon as Close is called (see #224). // - // This leaks a goroutine, but that is unavoidable since AFAIK there is no - // (easy and portable) way to guarantee that reads of stdin are unblocked - // when closed. + // This leaks a goroutine if rwc.Read does not unblock after it is closed, + // but that is unavoidable since AFAIK there is no (easy and portable) way to + // guarantee that reads of stdin are unblocked when closed. go func() { dec := json.NewDecoder(rwc) for { From dae88535babd1d03698e8d0d49cc5eeda4775c9e Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Mon, 11 Aug 2025 16:43:46 -0400 Subject: [PATCH 090/221] mcp: support stateless streamable sessions (#277) Support stateless streamable sessions by adding a GetSessionID function to StreamableHTTPOptions. If GetSessionID returns "", the session is stateless, and no validation is performed. This is implemented by providing the session a trivial initialization state. To implement this, some parts of #232 (distributed sessions) are copied over, since they add an API for creating an already-initialized session. In total, the following new API is added: - StreamableHTTPOptions.GetSessionID - ServerSessionOptions (a new parameter to Server.Connect) - ServerSessionState - ClientSessionOptions (a new parameter to Client.Connect, for symmetry) For #10 --- README.md | 2 +- examples/client/listfeatures/main.go | 2 +- internal/readme/client/client.go | 2 +- mcp/client.go | 44 +++++++----- mcp/client_list_test.go | 10 +-- mcp/cmd_test.go | 6 +- mcp/conformance_test.go | 2 +- mcp/example_middleware_test.go | 4 +- mcp/logging.go | 2 +- mcp/mcp_test.go | 28 ++++---- mcp/server.go | 101 +++++++++++++++++---------- mcp/server_example_test.go | 8 +-- mcp/session.go | 29 ++++++++ mcp/sse.go | 2 +- mcp/sse_example_test.go | 2 +- mcp/sse_test.go | 2 +- mcp/streamable.go | 59 +++++++++++++--- mcp/streamable_test.go | 80 +++++++++++++++++++-- mcp/transport.go | 39 +++++++---- 19 files changed, 305 insertions(+), 119 deletions(-) create mode 100644 mcp/session.go diff --git a/README.md b/README.md index 76800430..ded7586b 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ func main() { // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go index caf21bfe..00aa459b 100644 --- a/examples/client/listfeatures/main.go +++ b/examples/client/listfeatures/main.go @@ -41,7 +41,7 @@ func main() { ctx := context.Background() cmd := exec.Command(args[0], args[1:]...) client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) - cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) if err != nil { log.Fatal(err) } diff --git a/internal/readme/client/client.go b/internal/readme/client/client.go index 666ee925..57ec54fa 100644 --- a/internal/readme/client/client.go +++ b/internal/readme/client/client.go @@ -21,7 +21,7 @@ func main() { // Connect to a server over stdin/stdout transport := mcp.NewCommandTransport(exec.Command("myserver")) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/client.go b/mcp/client.go index d3139f54..5798dc5a 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -71,10 +71,11 @@ type ClientOptions struct { // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. -func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession { - cs := &ClientSession{ - conn: conn, - client: c, +func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession { + assert(mcpConn != nil && conn != nil, "nil connection") + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c} + if state != nil { + cs.state = *state } c.mu.Lock() defer c.mu.Unlock() @@ -101,6 +102,10 @@ func (e unsupportedProtocolVersionError) Error() string { return fmt.Sprintf("unsupported protocol version: %q", e.version) } +// ClientSessionOptions is reserved for future use. +type ClientSessionOptions struct { +} + // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. // @@ -108,8 +113,8 @@ func (e unsupportedProtocolVersionError) Error() string { // when it is no longer needed. However, if the connection is closed by the // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. -func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) { - cs, err = connect(ctx, t, c) +func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) { + cs, err = connect(ctx, t, c, (*clientSessionState)(nil)) if err != nil { return nil, err } @@ -133,9 +138,9 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { return nil, unsupportedProtocolVersionError{res.ProtocolVersion} } - cs.initializeResult = res + cs.state.InitializeResult = res if hc, ok := cs.mcpConn.(clientConnection); ok { - hc.initialized(res) + hc.sessionUpdated(cs.state) } if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { _ = cs.Close() @@ -156,22 +161,25 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e // Call [ClientSession.Close] to close the connection, or await server // termination with [ClientSession.Wait]. type ClientSession struct { - conn *jsonrpc2.Connection - client *Client - initializeResult *InitializeResult - keepaliveCancel context.CancelFunc - mcpConn Connection + conn *jsonrpc2.Connection + client *Client + keepaliveCancel context.CancelFunc + mcpConn Connection + + // No mutex is (currently) required to guard the session state, because it is + // only set synchronously during Client.Connect. + state clientSessionState } -func (cs *ClientSession) setConn(c Connection) { - cs.mcpConn = c +type clientSessionState struct { + InitializeResult *InitializeResult } func (cs *ClientSession) ID() string { - if cs.mcpConn == nil { - return "" + if c, ok := cs.mcpConn.(hasSessionID); ok { + return c.SessionID() } - return cs.mcpConn.SessionID() + return "" } // Close performs a graceful close of the connection, preventing new requests diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 5b13a4c8..836d4803 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -38,7 +38,7 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools) + testIterator(t, clientSession.Tools(ctx, nil), wantTools) }) }) @@ -60,7 +60,7 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources) + testIterator(t, clientSession.Resources(ctx, nil), wantResources) }) }) @@ -81,7 +81,7 @@ func TestList(t *testing.T) { } }) t.Run("ResourceTemplatesIterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates) + testIterator(t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates) }) }) @@ -102,12 +102,12 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts) + testIterator(t, clientSession.Prompts(ctx, nil), wantPrompts) }) }) } -func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) { +func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { t.Helper() var got []*T for x, err := range seq { diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index fd807cfd..a021bc86 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -81,7 +81,7 @@ func TestServerRunContextCancel(t *testing.T) { // send a ping to the server to ensure it's running client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) - session, err := client.Connect(ctx, clientTransport) + session, err := client.Connect(ctx, clientTransport, nil) if err != nil { t.Fatal(err) } @@ -116,7 +116,7 @@ func TestServerInterrupt(t *testing.T) { cmd := createServerCommand(t, "default") client := mcp.NewClient(testImpl, nil) - _, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + _, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) if err != nil { t.Fatal(err) } @@ -198,7 +198,7 @@ func TestCmdTransport(t *testing.T) { cmd := createServerCommand(t, "default") client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) - session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) if err != nil { t.Fatal(err) } diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 883d8a89..8e6ea1be 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -135,7 +135,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Connect the server, and connect the client stream, // but don't connect an actual client. cTransport, sTransport := NewInMemoryTransports() - ss, err := s.Connect(ctx, sTransport) + ss, err := s.Connect(ctx, sTransport, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 597b9dcd..c91250c3 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -114,10 +114,10 @@ func Example_loggingMiddleware() { ctx := context.Background() // Connect server and client - serverSession, _ := server.Connect(ctx, serverTransport) + serverSession, _ := server.Connect(ctx, serverTransport, nil) defer serverSession.Close() - clientSession, _ := client.Connect(ctx, clientTransport) + clientSession, _ := client.Connect(ctx, clientTransport, nil) defer clientSession.Close() // Call the tool to demonstrate logging diff --git a/mcp/logging.go b/mcp/logging.go index 4880e179..4d33097a 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -117,7 +117,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { // This is also checked in ServerSession.LoggingMessage, so checking it here // is just an optimization that skips building the JSON. h.ss.mu.Lock() - mcpLevel := h.ss.logLevel + mcpLevel := h.ss.state.LogLevel h.ss.mu.Unlock() return level >= mcpLevelToSlog(mcpLevel) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 48e95de2..9e9a6a30 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -104,7 +104,7 @@ func TestEndToEnd(t *testing.T) { s.AddResource(resource2, readHandler) // Connect the server. - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -148,7 +148,7 @@ func TestEndToEnd(t *testing.T) { c.AddRoots(&Root{URI: "file://" + rootAbs}) // Connect the client. - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -549,13 +549,13 @@ func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *Clien if config != nil { config(s) } - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -598,7 +598,7 @@ func TestBatching(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - _, err := s.Connect(ctx, st) + _, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -608,7 +608,7 @@ func TestBatching(t *testing.T) { // 'initialize' to block. Therefore, we can only test with a size of 1. // Since batching is being removed, we can probably just delete this. const batchSize = 1 - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -668,7 +668,7 @@ func TestMiddleware(t *testing.T) { ct, st := NewInMemoryTransports() s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -695,7 +695,7 @@ func TestMiddleware(t *testing.T) { c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -777,13 +777,13 @@ func TestNoJSONNull(t *testing.T) { ct = NewLoggingTransport(ct, &logbuf) s := NewServer(testImpl, nil) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -845,7 +845,7 @@ func TestKeepAlive(t *testing.T) { s := NewServer(testImpl, serverOpts) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -855,7 +855,7 @@ func TestKeepAlive(t *testing.T) { KeepAlive: 100 * time.Millisecond, } c := NewClient(testImpl, clientOpts) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -889,7 +889,7 @@ func TestKeepAliveFailure(t *testing.T) { // Server without keepalive (to test one-sided keepalive) s := NewServer(testImpl, nil) AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -899,7 +899,7 @@ func TestKeepAliveFailure(t *testing.T) { KeepAlive: 50 * time.Millisecond, } c := NewClient(testImpl, clientOpts) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/server.go b/mcp/server.go index d81dec60..89f3b6c9 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -301,7 +301,7 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr }) } -func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, ss *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { s.mu.Lock() prompt, ok := s.prompts.get(params.Name) s.mu.Unlock() @@ -309,7 +309,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr // TODO: surface the error code over the wire, instead of flattening it into the string. return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) } - return prompt.handler(ctx, cc, params) + return prompt.handler(ctx, ss, params) } func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { @@ -518,7 +518,7 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns // It need not be called on servers that are used for multiple concurrent connections, // as with [StreamableHTTPHandler]. func (s *Server) Run(ctx context.Context, t Transport) error { - ss, err := s.Connect(ctx, t) + ss, err := s.Connect(ctx, t, nil) if err != nil { return err } @@ -539,8 +539,12 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { - ss := &ServerSession{conn: conn, server: s} +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession { + assert(mcpConn != nil && conn != nil, "nil connection") + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s} + if state != nil { + ss.state = *state + } s.mu.Lock() s.sessions = append(s.sessions, ss) s.mu.Unlock() @@ -561,32 +565,50 @@ func (s *Server) disconnect(cc *ServerSession) { } } +// ServerSessionOptions configures the server session. +type ServerSessionOptions struct { + State *ServerSessionState +} + // Connect connects the MCP server over the given transport and starts handling // messages. // // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) { - return connect(ctx, t, s) +// +// If opts.State is non-nil, it is the initial state for the server. +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + var state *ServerSessionState + if opts != nil { + state = opts.State + } + return connect(ctx, t, s, state) } +// TODO: (nit) move all ServerSession methods below the ServerSession declaration. func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if params == nil { + // Since we use nilness to signal 'initialized' state, we must ensure that + // params are non-nil. + params = new(InitializedParams) + } if ss.server.opts.KeepAlive > 0 { ss.startKeepalive(ss.server.opts.KeepAlive) } - ss.mu.Lock() - hasParams := ss.initializeParams != nil - wasInitialized := ss._initialized - if hasParams { - ss._initialized = true - } - ss.mu.Unlock() + var wasInit, wasInitd bool + ss.updateState(func(state *ServerSessionState) { + wasInit = state.InitializeParams != nil + wasInitd = state.InitializedParams != nil + if wasInit && !wasInitd { + state.InitializedParams = params + } + }) - if !hasParams { + if !wasInit { return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) } - if wasInitialized { + if wasInitd { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) @@ -615,25 +637,30 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { - server *Server - conn *jsonrpc2.Connection - mcpConn Connection - mu sync.Mutex - logLevel LoggingLevel - initializeParams *InitializeParams - _initialized bool - keepaliveCancel context.CancelFunc + server *Server + conn *jsonrpc2.Connection + mcpConn Connection + keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + + mu sync.Mutex + state ServerSessionState } -func (ss *ServerSession) setConn(c Connection) { - ss.mcpConn = c +func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { + ss.mu.Lock() + mut(&ss.state) + copy := ss.state + ss.mu.Unlock() + if c, ok := ss.mcpConn.(serverConnection); ok { + c.sessionUpdated(copy) + } } func (ss *ServerSession) ID() string { - if ss.mcpConn == nil { - return "" + if c, ok := ss.mcpConn.(hasSessionID); ok { + return c.SessionID() } - return ss.mcpConn.SessionID() + return "" } // Ping pings the client. @@ -657,7 +684,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { ss.mu.Lock() - logLevel := ss.logLevel + logLevel := ss.state.LogLevel ss.mu.Unlock() if logLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client @@ -747,7 +774,7 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() - initialized := ss._initialized + initialized := ss.state.InitializedParams != nil ss.mu.Unlock() // From the spec: // "The client SHOULD NOT send requests other than pings before the server @@ -770,9 +797,9 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } - ss.mu.Lock() - ss.initializeParams = params - ss.mu.Unlock() + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = params + }) // If we support the client's version, reply with it. Otherwise, reply with our // latest version. @@ -796,9 +823,9 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error } func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.logLevel = params.Level + ss.updateState(func(state *ServerSessionState) { + state.LogLevel = params.Level + }) return &emptyResult{}, nil } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 3ab7a2a4..241008e9 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -31,13 +31,13 @@ func ExampleServer() { server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } client := mcp.NewClient(&mcp.Implementation{Name: "client"}, nil) - clientSession, err := client.Connect(ctx, clientTransport) + clientSession, err := client.Connect(ctx, clientTransport, nil) if err != nil { log.Fatal(err) } @@ -62,11 +62,11 @@ func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession server := mcp.NewServer(testImpl, nil) client := mcp.NewClient(testImpl, nil) serverTransport, clientTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) + serverSession, err := server.Connect(ctx, serverTransport, nil) if err != nil { log.Fatal(err) } - clientSession, err := client.Connect(ctx, clientTransport) + clientSession, err := client.Connect(ctx, clientTransport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/session.go b/mcp/session.go new file mode 100644 index 00000000..dcf9888c --- /dev/null +++ b/mcp/session.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// hasSessionID is the interface which, if implemented by connections, informs +// the session about their session ID. +// +// TODO(rfindley): remove SessionID methods from connections, when it doesn't +// make sense. Or remove it from the Sessions entirely: why does it even need +// to be exposed? +type hasSessionID interface { + SessionID() string +} + +// ServerSessionState is the state of a session. +type ServerSessionState struct { + // InitializeParams are the parameters from 'initialize'. + InitializeParams *InitializeParams `json:"initializeParams"` + + // InitializedParams are the parameters from 'notifications/initialized'. + InitializedParams *InitializedParams `json:"initializedParams"` + + // LogLevel is the logging level for the session. + LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} diff --git a/mcp/sse.go b/mcp/sse.go index bdc4770b..f74a3fb6 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -221,7 +221,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Error(w, "no server available", http.StatusBadRequest) return } - ss, err := server.Connect(req.Context(), transport) + ss, err := server.Connect(req.Context(), transport, nil) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index d8ce939b..9a7c8ae7 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -37,7 +37,7 @@ func ExampleSSEHandler() { ctx := context.Background() transport := mcp.NewSSEClientTransport(httpServer.URL, nil) client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "v1.0.0"}, nil) - cs, err := client.Connect(ctx, transport) + cs, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 35fdbdbf..79cfacc3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -49,7 +49,7 @@ func TestSSEServer(t *testing.T) { }) c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, clientTransport) + cs, err := c.Connect(ctx, clientTransport, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/streamable.go b/mcp/streamable.go index 108de5d2..72ac7e83 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -42,9 +42,14 @@ type StreamableHTTPHandler struct { sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } -// StreamableHTTPOptions is a placeholder options struct for future -// configuration of the StreamableHTTP handler. +// StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { + // GetSessionID provides the next session ID to use for an incoming request. + // + // If GetSessionID returns an empty string, the session is 'stateless', + // meaning it is not persisted and no session validation is performed. + GetSessionID func() string + // TODO: support configurable session ID generation (?) // TODO: support session retention (?) @@ -66,6 +71,9 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea if opts != nil { h.opts = *opts } + if h.opts.GetSessionID == nil { + h.opts.GetSessionID = randText + } return h } @@ -138,6 +146,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque switch req.Method { case http.MethodPost, http.MethodGet: + if req.Method == http.MethodGet && session == nil { + http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) + return + } default: w.Header().Set("Allow", "GET, POST") http.Error(w, "unsupported method", http.StatusMethodNotAllowed) @@ -145,23 +157,42 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque } if session == nil { - s := NewStreamableServerTransport(randText(), h.opts.transportOptions) server := h.getServer(req) if server == nil { // The getServer argument to NewStreamableHTTPHandler returned nil. http.Error(w, "no server available", http.StatusBadRequest) return } + sessionID := h.opts.GetSessionID() + s := NewStreamableServerTransport(sessionID, h.opts.transportOptions) + + // To support stateless mode, we initialize the session with a default + // state, so that it doesn't reject subsequent requests. + var connectOpts *ServerSessionOptions + if sessionID == "" { + connectOpts = &ServerSessionOptions{ + State: &ServerSessionState{ + InitializeParams: new(InitializeParams), + InitializedParams: new(InitializedParams), + }, + } + } // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - if _, err := server.Connect(req.Context(), s); err != nil { + ss, err := server.Connect(req.Context(), s, connectOpts) + if err != nil { http.Error(w, "failed connection", http.StatusInternalServerError) return } - h.sessionsMu.Lock() - h.sessions[s.sessionID] = s - h.sessionsMu.Unlock() + if sessionID == "" { + // Stateless mode: close the session when the request exits. + defer ss.Close() // close the fake session after handling the request + } else { + h.sessionsMu.Lock() + h.sessions[s.sessionID] = s + h.sessionsMu.Unlock() + } session = s } @@ -485,7 +516,9 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R func (t *StreamableServerTransport) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "application/json") - w.Header().Set(sessionIDHeader, t.sessionID) + if t.sessionID != "" { + w.Header().Set(sessionIDHeader, t.sessionID) + } var msgs []json.RawMessage ctx := req.Context() @@ -524,7 +557,9 @@ func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWr w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") - w.Header().Set(sessionIDHeader, t.sessionID) + if t.sessionID != "" { + w.Header().Set(sessionIDHeader, t.sessionID) + } // write one event containing data. write := func(data []byte) bool { @@ -893,9 +928,11 @@ type streamableClientConn struct { sessionID string } -func (c *streamableClientConn) initialized(res *InitializeResult) { +var _ clientConnection = (*streamableClientConn)(nil) + +func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.mu.Lock() - c.initializedResult = res + c.initializedResult = state.InitializeResult c.mu.Unlock() // Start the persistent SSE listener as soon as we have the initialized diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 24368b00..54803939 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -81,7 +81,7 @@ func TestStreamableTransports(t *testing.T) { HTTPClient: httpClient, }) client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -173,7 +173,7 @@ func TestClientReplay(t *testing.T) { notifications <- params.Message }, }) - clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil)) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil), nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -239,7 +239,7 @@ func TestServerInitiatedSSE(t *testing.T) { notifications <- "toolListChanged" }, }) - clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil), nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -767,12 +767,12 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { transport := NewStreamableClientTransport(httpServer.URL, nil) client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } defer session.Close() - if diff := cmp.Diff(initResult, session.initializeResult); diff != "" { + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } @@ -821,3 +821,73 @@ func TestEventID(t *testing.T) { }) } } +func TestStreamableStateless(t *testing.T) { + // Test stateless mode behavior + ctx := context.Background() + + // This version of sayHi doesn't make a ping request (we can't respond to + // that request from our client). + sayHi := func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + params.Arguments.Name}}}, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + + // Test stateless mode. + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { return "" }, + }) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + checkRequest := func(body string) { + // Verify we can call tools/list directly without initialization in stateless mode + req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Verify that no session ID header is returned in stateless mode + if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + t.Errorf("%s = %s, want no session ID header", sessionIDHeader, sessionID) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Status code = %d; want successful response", resp.StatusCode) + } + + var events []Event + for event, err := range scanEvents(resp.Body) { + if err != nil { + t.Fatal(err) + } + events = append(events, event) + } + if len(events) != 1 { + t.Fatalf("got %d SSE events, want 1; events: %v", len(events), events) + } + msg, err := jsonrpc.DecodeMessage(events[0].Data) + if err != nil { + t.Fatal(err) + } + jsonResp, ok := msg.(*jsonrpc.Response) + if !ok { + t.Errorf("event is %T, want response", jsonResp) + } + if jsonResp.Error != nil { + t.Errorf("request failed: %v", jsonResp.Error) + } + } + + checkRequest(`{"jsonrpc":"2.0","method":"tools/list","id":1,"params":{}}`) + + // Verify we can make another request without session ID + checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`) +} diff --git a/mcp/transport.go b/mcp/transport.go index 6ffb67f5..f45d7d2b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -60,11 +60,28 @@ type Connection interface { SessionID() string } -// A clientConnection is a [Connection] that is specific to the MCP client, and -// so may receive information about the client session. +// A ClientConnection is a [Connection] that is specific to the MCP client. +// +// If client connections implement this interface, they may receive information +// about changes to the client session. +// +// TODO: should this interface be exported? type clientConnection interface { Connection - initialized(*InitializeResult) + + // SessionUpdated is called whenever the client session state changes. + sessionUpdated(clientSessionState) +} + +// A serverConnection is a Connection that is specific to the MCP server. +// +// If server connections implement this interface, they receive information +// about changes to the server session. +// +// TODO: should this interface be exported? +type serverConnection interface { + Connection + sessionUpdated(ServerSessionState) } // A StdioTransport is a [Transport] that communicates over stdin/stdout using @@ -102,37 +119,36 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { return &InMemoryTransport{ioTransport{c1}}, &InMemoryTransport{ioTransport{c2}} } -type binder[T handler] interface { - bind(*jsonrpc2.Connection) T +type binder[T handler, State any] interface { + bind(Connection, *jsonrpc2.Connection, State) T disconnect(T) } type handler interface { handle(ctx context.Context, req *jsonrpc.Request) (any, error) - setConn(Connection) } -func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) { +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) { var zero H - conn, err := t.Connect(ctx) + mcpConn, err := t.Connect(ctx) if err != nil { return zero, err } // If logging is configured, write message logs. - reader, writer := jsonrpc2.Reader(conn), jsonrpc2.Writer(conn) + reader, writer := jsonrpc2.Reader(mcpConn), jsonrpc2.Writer(mcpConn) var ( h H preempter canceller ) bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { - h = b.bind(conn) + h = b.bind(mcpConn, conn, s) preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ Reader: reader, Writer: writer, - Closer: conn, + Closer: mcpConn, Bind: bind, Preempter: &preempter, OnDone: func() { @@ -141,7 +157,6 @@ func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, }) assert(preempter.conn != nil, "unbound preempter") - h.setConn(conn) return h, nil } From 679f777825cc6c49e2248549917d6500c5f149a2 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Tue, 12 Aug 2025 14:14:57 -0400 Subject: [PATCH 091/221] Update pull_request_template.md (#287) pull_request_template: update template to reflect correct form --- .github/pull_request_template.md | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index be13052c..3a0b3d8c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,12 +1,28 @@ -### PR Tips +### PR Guideline Typically, PRs should consist of a single commit, and so should generally follow -the [rules for Go commit messages](https://go.dev/wiki/CommitMessage), with the following -changes and additions: +the [rules for Go commit messages](https://go.dev/wiki/CommitMessage). -- Markdown is allowed. +You **must** follow the form: -- For a pervasive change, use "all" in the title instead of a package name. +``` +net/http: handle foo when bar + +[longer description here in the body] + +Fixes #12345 +``` +Notably, for the subject (the first line of description): +- the name of the package affected by the change goes before the colon +- the part after the colon uses the verb tense + phrase that completes the blank in, “this change modifies this package to ___________” +- the verb after the colon is lowercase +- there is no trailing period +- it should be kept as short as possible + +Additionally: + +- Markdown is allowed. +- For a pervasive change, use "all" in the title instead of a package name. - The PR description should provide context (why this change?) and describe the changes at a high level. Changes that are obvious from the diffs don't need to be mentioned. From e097918484536115cefb48e605eb1f37a3598d8e Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Tue, 12 Aug 2025 15:14:21 -0400 Subject: [PATCH 092/221] mcp: make transports open structs As described in #272, there is really no reason for transports to be closed structs with constructors, since their state must be established by a call to Connect. Making them open structs simplifies their APIs, and means that all transports can be extended in the future: we don't have to create empty Options structs just for the purpose of future compatibility. For now, the related constructors and options structs are simply deprecated (with go:fix directives where possible). A future CL will remove them prior to the v1.0.0 release. For #272 --- README.md | 4 +- design/design.md | 81 +++---- examples/client/listfeatures/main.go | 2 +- examples/server/hello/main.go | 2 +- examples/server/memory/main.go | 2 +- examples/server/sequentialthinking/main.go | 2 +- internal/readme/client/client.go | 2 +- internal/readme/server/server.go | 2 +- mcp/cmd.go | 16 +- mcp/cmd_test.go | 8 +- mcp/mcp_test.go | 2 +- mcp/sse.go | 106 ++++---- mcp/sse_example_test.go | 2 +- mcp/sse_test.go | 5 +- mcp/streamable.go | 266 ++++++++++++--------- mcp/streamable_test.go | 13 +- mcp/transport.go | 43 ++-- 17 files changed, 307 insertions(+), 251 deletions(-) diff --git a/README.md b/README.md index ded7586b..22a3fed3 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ func main() { client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) // Connect to a server over stdin/stdout - transport := mcp.NewCommandTransport(exec.Command("myserver")) + transport := &mcp.CommandTransport{Command: exec.Command("myserver")} session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) @@ -127,7 +127,7 @@ func main() { mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects - if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { log.Fatal(err) } } diff --git a/design/design.md b/design/design.md index 93dc5521..65a2b61f 100644 --- a/design/design.md +++ b/design/design.md @@ -100,11 +100,7 @@ The `CommandTransport` is the client side of the stdio transport, and connects b ```go // A CommandTransport is a [Transport] that runs a command and communicates // with it over stdin/stdout, using newline-delimited JSON. -type CommandTransport struct { /* unexported fields */ } - -// NewCommandTransport returns a [CommandTransport] that runs the given command -// and communicates with it over stdin/stdout. -func NewCommandTransport(cmd *exec.Command) *CommandTransport +type CommandTransport struct { Command *exec.Command } // Connect starts the command, and connects to it over stdin/stdout. func (*CommandTransport) Connect(ctx context.Context) (Connection, error) { @@ -115,9 +111,7 @@ The `StdioTransport` is the server side of the stdio transport, and connects by ```go // A StdioTransport is a [Transport] that communicates using newline-delimited // JSON over stdin/stdout. -type StdioTransport struct { /* unexported fields */ } - -func NewStdioTransport() *StdioTransport +type StdioTransport struct { } func (t *StdioTransport) Connect(context.Context) (Connection, error) ``` @@ -128,6 +122,8 @@ The HTTP transport APIs are even more asymmetrical. Since connections are initia Importantly, since they serve many connections, the HTTP handlers must accept a callback to get an MCP server for each new session. As described below, MCP servers can optionally connect to multiple clients. This allows customization of per-session servers: if the MCP server is stateless, the user can return the same MCP server for each connection. On the other hand, if any per-session customization is required, it is possible by returning a different `Server` instance for each connection. +Both the SSE and Streamable HTTP server transports are http.Handlers which serve messages to their associated connection. Consequently, they can be connected at most once. + ```go // SSEHTTPHandler is an http.Handler that serves SSE-based MCP sessions as defined by // the 2024-11-05 version of the MCP protocol. @@ -153,26 +149,10 @@ By default, the SSE handler creates messages endpoints with the `?sessionId=...` ```go // A SSEServerTransport is a logical SSE session created through a hanging GET // request. -// -// When connected, it returns the following [Connection] implementation: -// - Writes are SSE 'message' events to the GET response. -// - Reads are received from POSTs to the session endpoint, via -// [SSEServerTransport.ServeHTTP]. -// - Close terminates the hanging GET. -type SSEServerTransport struct { /* ... */ } - -// NewSSEServerTransport creates a new SSE transport for the given messages -// endpoint, and hanging GET response. -// -// Use [SSEServerTransport.Connect] to initiate the flow of messages. -// -// The transport is itself an [http.Handler]. It is the caller's responsibility -// to ensure that the resulting transport serves HTTP requests on the given -// session endpoint. -// -// Most callers should instead use an [SSEHandler], which transparently handles -// the delegation to SSEServerTransports. -func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport +type SSEServerTransport struct { + Endpoint string + Response http.ResponseWriter +} // ServeHTTP handles POST requests to the transport endpoint. func (*SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) @@ -185,20 +165,14 @@ func (*SSEServerTransport) Connect(context.Context) (Connection, error) The SSE client transport is simpler, and hopefully self-explanatory. ```go -type SSEClientTransport struct { /* ... */ } - -// SSEClientTransportOptions provides options for the [NewSSEClientTransport] -// constructor. -type SSEClientTransportOptions struct { +type SSEClientTransport struct { + // Endpoint is the SSE endpoint to connect to. + Endpoint string // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client } -// NewSSEClientTransport returns a new client transport that connects to the -// SSE server at the provided URL. -func NewSSEClientTransport(url string, opts *SSEClientTransportOptions) (*SSEClientTransport, error) - // Connect connects through the client endpoint. func (*SSEClientTransport) Connect(ctx context.Context) (Connection, error) ``` @@ -218,23 +192,22 @@ func (*StreamableHTTPHandler) Close() error // session ID, not an endpoint, along with the HTTP response for the request // that created the session. It is the caller's responsibility to delegate // requests to this session. -type StreamableServerTransport struct { /* ... */ } -func NewStreamableServerTransport(sessionID string) *StreamableServerTransport +type StreamableServerTransport struct { + // SessionID is the ID of this session. + SessionID string + // Storage for events, to enable stream resumption. + EventStore EventStore +} func (*StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) func (*StreamableServerTransport) Connect(context.Context) (Connection, error) // The streamable client handles reconnection transparently to the user. -type StreamableClientTransport struct { /* ... */ } - -// StreamableClientTransportOptions provides options for the -// [NewStreamableClientTransport] constructor. -type StreamableClientTransportOptions struct { - // HTTPClient is the client to use for making HTTP requests. If nil, - // http.DefaultClient is used. - HTTPClient *http.Client +type StreamableClientTransport struct { + Endpoint string + HTTPClient *http.Client + ReconnectOptions *StreamableReconnectOptions } -func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport func (*StreamableClientTransport) Connect(context.Context) (Connection, error) ``` @@ -257,8 +230,10 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) // A LoggingTransport is a [Transport] that delegates to another transport, // writing RPC logs to an io.Writer. -type LoggingTransport struct { /* ... */ } -func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport +type LoggingTransport struct { + Delegate Transport + Writer io.Writer +} ``` ### Protocol types @@ -358,7 +333,9 @@ Here's an example of these APIs from the client side: ```go client := mcp.NewClient(&mcp.Implementation{Name:"mcp-client", Version:"v1.0.0"}, nil) // Connect to a server over stdin/stdout -transport := mcp.NewCommandTransport(exec.Command("myserver")) +transport := &mcp.CommandTransport{ + Command: exec.Command("myserver"}, +} session, err := client.Connect(ctx, transport) if err != nil { ... } // Call a tool on the server. @@ -374,7 +351,7 @@ A server that can handle that client call would look like this: server := mcp.NewServer(&mcp.Implementation{Name:"greeter", Version:"v1.0.0"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects. -if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { +if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { log.Fatal(err) } ``` diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go index 00aa459b..755a4f98 100644 --- a/examples/client/listfeatures/main.go +++ b/examples/client/listfeatures/main.go @@ -41,7 +41,7 @@ func main() { ctx := context.Background() cmd := exec.Command(args[0], args[1:]...) client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) - cs, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) + cs, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) if err != nil { log.Fatal(err) } diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index f01a6c99..00fb37a6 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -58,7 +58,7 @@ func main() { log.Printf("MCP handler listening at %s", *httpAddr) http.ListenAndServe(*httpAddr, handler) } else { - t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} if err := server.Run(context.Background(), t); err != nil { log.Printf("Server failed: %v", err) } diff --git a/examples/server/memory/main.go b/examples/server/memory/main.go index 61ab1060..99d109a6 100644 --- a/examples/server/memory/main.go +++ b/examples/server/memory/main.go @@ -137,7 +137,7 @@ func main() { log.Printf("MCP handler listening at %s", *httpAddr) http.ListenAndServe(*httpAddr, handler) } else { - t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} if err := server.Run(context.Background(), t); err != nil { log.Printf("Server failed: %v", err) } diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go index 4830bb0a..e9cb594d 100644 --- a/examples/server/sequentialthinking/main.go +++ b/examples/server/sequentialthinking/main.go @@ -536,7 +536,7 @@ func main() { log.Fatal(err) } } else { - t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} if err := server.Run(context.Background(), t); err != nil { log.Printf("Server failed: %v", err) } diff --git a/internal/readme/client/client.go b/internal/readme/client/client.go index 57ec54fa..e2794f8b 100644 --- a/internal/readme/client/client.go +++ b/internal/readme/client/client.go @@ -20,7 +20,7 @@ func main() { client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) // Connect to a server over stdin/stdout - transport := mcp.NewCommandTransport(exec.Command("myserver")) + transport := &mcp.CommandTransport{Command: exec.Command("myserver")} session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 8901e773..3746e194 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -28,7 +28,7 @@ func main() { mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects - if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { log.Fatal(err) } } diff --git a/mcp/cmd.go b/mcp/cmd.go index 163bb0ca..5ec8c9e7 100644 --- a/mcp/cmd.go +++ b/mcp/cmd.go @@ -16,7 +16,7 @@ import ( // A CommandTransport is a [Transport] that runs a command and communicates // with it over stdin/stdout, using newline-delimited JSON. type CommandTransport struct { - cmd *exec.Cmd + Command *exec.Cmd } // NewCommandTransport returns a [CommandTransport] that runs the given command @@ -24,25 +24,29 @@ type CommandTransport struct { // // The resulting transport takes ownership of the command, starting it during // [CommandTransport.Connect], and stopping it when the connection is closed. +// +// Deprecated: use a CommandTransport literal. +// +//go:fix inline func NewCommandTransport(cmd *exec.Cmd) *CommandTransport { - return &CommandTransport{cmd} + return &CommandTransport{Command: cmd} } // Connect starts the command, and connects to it over stdin/stdout. func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { - stdout, err := t.cmd.StdoutPipe() + stdout, err := t.Command.StdoutPipe() if err != nil { return nil, err } stdout = io.NopCloser(stdout) // close the connection by closing stdin, not stdout - stdin, err := t.cmd.StdinPipe() + stdin, err := t.Command.StdinPipe() if err != nil { return nil, err } - if err := t.cmd.Start(); err != nil { + if err := t.Command.Start(); err != nil { return nil, err } - return newIOConn(&pipeRWC{t.cmd, stdout, stdin}), nil + return newIOConn(&pipeRWC{t.Command, stdout, stdin}), nil } // A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index a021bc86..6c3a1a76 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -49,7 +49,7 @@ func runServer() { server := mcp.NewServer(testImpl, nil) mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { + if err := server.Run(ctx, &mcp.StdioTransport{}); err != nil { log.Fatal(err) } } @@ -59,7 +59,7 @@ func runCancelContextServer() { defer done() server := mcp.NewServer(testImpl, nil) - if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { + if err := server.Run(ctx, &mcp.StdioTransport{}); err != nil { log.Fatal(err) } } @@ -116,7 +116,7 @@ func TestServerInterrupt(t *testing.T) { cmd := createServerCommand(t, "default") client := mcp.NewClient(testImpl, nil) - _, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) + _, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) if err != nil { t.Fatal(err) } @@ -198,7 +198,7 @@ func TestCmdTransport(t *testing.T) { cmd := createServerCommand(t, "default") client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) - session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd), nil) + session, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) if err != nil { t.Fatal(err) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 9e9a6a30..9e3eccd0 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -774,7 +774,7 @@ func TestNoJSONNull(t *testing.T) { // Collect logs, to sanity check that we don't write JSON null anywhere. var logbuf safeBuffer - ct = NewLoggingTransport(ct, &logbuf) + ct = &LoggingTransport{Transport: ct, Writer: &logbuf} s := NewServer(testImpl, nil) ss, err := s.Connect(ctx, st, nil) diff --git a/mcp/sse.go b/mcp/sse.go index f74a3fb6..b7f0d4e2 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -74,47 +74,66 @@ func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { // A SSEServerTransport is a logical SSE session created through a hanging GET // request. // +// Use [SSEServerTransport.Connect] to initiate the flow of messages. +// // When connected, it returns the following [Connection] implementation: // - Writes are SSE 'message' events to the GET response. // - Reads are received from POSTs to the session endpoint, via // [SSEServerTransport.ServeHTTP]. // - Close terminates the hanging GET. +// +// The transport is itself an [http.Handler]. It is the caller's responsibility +// to ensure that the resulting transport serves HTTP requests on the given +// session endpoint. +// +// Each SSEServerTransport may be connected (via [Server.Connect]) at most +// once, since [SSEServerTransport.ServeHTTP] serves messages to the connected +// session. +// +// Most callers should instead use an [SSEHandler], which transparently handles +// the delegation to SSEServerTransports. type SSEServerTransport struct { - endpoint string - incoming chan jsonrpc.Message // queue of incoming messages; never closed + // Endpoint is the endpoint for this session, where the client can POST + // messages. + Endpoint string + + // Response is the hanging response body to the incoming GET request. + Response http.ResponseWriter + + // incoming is the queue of incoming messages. + // It is never closed, and by convention, incoming is non-nil if and only if + // the transport is connected. + incoming chan jsonrpc.Message // We must guard both pushes to the incoming queue and writes to the response // writer, because incoming POST requests are arbitrarily concurrent and we // need to ensure we don't write push to the queue, or write to the // ResponseWriter, after the session GET request exits. - mu sync.Mutex - w http.ResponseWriter // the hanging response body - closed bool // set when the stream is closed - done chan struct{} // closed when the connection is closed + mu sync.Mutex // also guards writes to Response + closed bool // set when the stream is closed + done chan struct{} // closed when the connection is closed } // NewSSEServerTransport creates a new SSE transport for the given messages // endpoint, and hanging GET response. // -// Use [SSEServerTransport.Connect] to initiate the flow of messages. -// -// The transport is itself an [http.Handler]. It is the caller's responsibility -// to ensure that the resulting transport serves HTTP requests on the given -// session endpoint. +// Deprecated: use an SSEServerTransport literal. // -// Most callers should instead use an [SSEHandler], which transparently handles -// the delegation to SSEServerTransports. +//go:fix inline func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport { return &SSEServerTransport{ - endpoint: endpoint, - w: w, - incoming: make(chan jsonrpc.Message, 100), - done: make(chan struct{}), + Endpoint: endpoint, + Response: w, } } // ServeHTTP handles POST requests to the transport endpoint. func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.incoming == nil { + http.Error(w, "session not connected", http.StatusInternalServerError) + return + } + // Read and parse the message. data, err := io.ReadAll(req.Body) if err != nil { @@ -146,12 +165,15 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // Connect sends the 'endpoint' event to the client. // See [SSEServerTransport] for more details on the [Connection] implementation. func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { - t.mu.Lock() - _, err := writeEvent(t.w, Event{ + if t.incoming != nil { + return nil, fmt.Errorf("already connected") + } + t.incoming = make(chan jsonrpc.Message, 100) + t.done = make(chan struct{}) + _, err := writeEvent(t.Response, Event{ Name: "endpoint", - Data: []byte(t.endpoint), + Data: []byte(t.Endpoint), }) - t.mu.Unlock() if err != nil { return nil, err } @@ -203,7 +225,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - transport := NewSSEServerTransport(endpoint.RequestURI(), w) + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} // The session is terminated when the request exits. h.mu.Lock() @@ -279,7 +301,7 @@ func (s *sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { return io.EOF } - _, err = writeEvent(s.t.w, Event{Name: "message", Data: data}) + _, err = writeEvent(s.t.Response, Event{Name: "message", Data: data}) return err } @@ -304,12 +326,18 @@ func (s *sseServerConn) Close() error { // // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEClientTransport struct { - sseEndpoint *url.URL - opts SSEClientTransportOptions + // Endpoint is the SSE endpoint to connect to. + Endpoint string + + // HTTPClient is the client to use for making HTTP requests. If nil, + // http.DefaultClient is used. + HTTPClient *http.Client } // SSEClientTransportOptions provides options for the [NewSSEClientTransport] // constructor. +// +// Deprecated: use an SSEClientTransport literal. type SSEClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. @@ -319,28 +347,28 @@ type SSEClientTransportOptions struct { // NewSSEClientTransport returns a new client transport that connects to the // SSE server at the provided URL. // -// NewSSEClientTransport panics if the given URL is invalid. -func NewSSEClientTransport(baseURL string, opts *SSEClientTransportOptions) *SSEClientTransport { - url, err := url.Parse(baseURL) - if err != nil { - panic(fmt.Sprintf("invalid base url: %v", err)) - } - t := &SSEClientTransport{ - sseEndpoint: url, - } +// Deprecated: use an SSEClientTransport literal. +// +//go:fix inline +func NewSSEClientTransport(endpoint string, opts *SSEClientTransportOptions) *SSEClientTransport { + t := &SSEClientTransport{Endpoint: endpoint} if opts != nil { - t.opts = *opts + t.HTTPClient = opts.HTTPClient } return t } // Connect connects through the client endpoint. func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { - req, err := http.NewRequestWithContext(ctx, "GET", c.sseEndpoint.String(), nil) + parsedURL, err := url.Parse(c.Endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint: %v", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", c.Endpoint, nil) if err != nil { return nil, err } - httpClient := c.opts.HTTPClient + httpClient := c.HTTPClient if httpClient == nil { httpClient = http.DefaultClient } @@ -362,7 +390,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") } raw := string(evt.Data) - return c.sseEndpoint.Parse(raw) + return parsedURL.Parse(raw) }() if err != nil { resp.Body.Close() @@ -372,7 +400,6 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientConn{ client: httpClient, - sseEndpoint: c.sseEndpoint, msgEndpoint: msgEndpoint, incoming: make(chan []byte, 100), body: resp.Body, @@ -404,7 +431,6 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // - Close terminates the GET request. type sseClientConn struct { client *http.Client // HTTP client to use for requests - sseEndpoint *url.URL // SSE endpoint for the GET msgEndpoint *url.URL // session endpoint for POSTs incoming chan []byte // queue of incoming messages diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 9a7c8ae7..cf1e75dc 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -35,7 +35,7 @@ func ExampleSSEHandler() { defer httpServer.Close() ctx := context.Background() - transport := mcp.NewSSEClientTransport(httpServer.URL, nil) + transport := &mcp.SSEClientTransport{Endpoint: httpServer.URL} client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "v1.0.0"}, nil) cs, err := client.Connect(ctx, transport, nil) if err != nil { diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 79cfacc3..408e92ec 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -44,9 +44,10 @@ func TestSSEServer(t *testing.T) { }), } - clientTransport := NewSSEClientTransport(httpServer.URL, &SSEClientTransportOptions{ + clientTransport := &SSEClientTransport{ + Endpoint: httpServer.URL, HTTPClient: customClient, - }) + } c := NewClient(testImpl, nil) cs, err := c.Connect(ctx, clientTransport, nil) diff --git a/mcp/streamable.go b/mcp/streamable.go index 72ac7e83..048b99aa 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -38,8 +38,8 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - sessionsMu sync.Mutex - sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) + mu sync.Mutex + transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } // StreamableHTTPOptions configures the StreamableHTTPHandler. @@ -50,12 +50,10 @@ type StreamableHTTPOptions struct { // meaning it is not persisted and no session validation is performed. GetSessionID func() string - // TODO: support configurable session ID generation (?) // TODO: support session retention (?) - // transportOptions sets the streamable server transport options to use when - // establishing a new session. - transportOptions *StreamableServerTransportOptions + // jsonResponse is forwarded to StreamableServerTransport.jsonResponse. + jsonResponse bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -65,8 +63,8 @@ type StreamableHTTPOptions struct { // If getServer returns nil, a 400 Bad Request will be served. func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { h := &StreamableHTTPHandler{ - getServer: getServer, - sessions: make(map[string]*StreamableServerTransport), + getServer: getServer, + transports: make(map[string]*StreamableServerTransport), } if opts != nil { h.opts = *opts @@ -85,12 +83,12 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea // Should we allow passing in a session store? That would allow the handler to // be stateless. func (h *StreamableHTTPHandler) closeAll() { - h.sessionsMu.Lock() - defer h.sessionsMu.Unlock() - for _, s := range h.sessions { - s.Close() + h.mu.Lock() + defer h.mu.Unlock() + for _, s := range h.transports { + s.connection.Close() } - h.sessions = nil + h.transports = nil } func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -119,9 +117,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var session *StreamableServerTransport if id := req.Header.Get(sessionIDHeader); id != "" { - h.sessionsMu.Lock() - session, _ = h.sessions[id] - h.sessionsMu.Unlock() + h.mu.Lock() + session, _ = h.transports[id] + h.mu.Unlock() if session == nil { http.Error(w, "session not found", http.StatusNotFound) return @@ -136,10 +134,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.sessionsMu.Lock() - delete(h.sessions, session.sessionID) - h.sessionsMu.Unlock() - session.Close() + h.mu.Lock() + delete(h.transports, session.SessionID) + h.mu.Unlock() + session.connection.Close() w.WriteHeader(http.StatusNoContent) return } @@ -164,7 +162,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } sessionID := h.opts.GetSessionID() - s := NewStreamableServerTransport(sessionID, h.opts.transportOptions) + s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse} // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. @@ -189,9 +187,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // Stateless mode: close the session when the request exits. defer ss.Close() // close the fake session after handling the request } else { - h.sessionsMu.Lock() - h.sessions[s.sessionID] = s - h.sessionsMu.Unlock() + h.mu.Lock() + h.transports[s.SessionID] = s + h.mu.Unlock() } session = s } @@ -199,10 +197,34 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque session.ServeHTTP(w, req) } +// StreamableServerTransportOptions configures the stramable server transport. +// +// Deprecated: use a StreamableServerTransport literal. type StreamableServerTransportOptions struct { // Storage for events, to enable stream resumption. // If nil, a [MemoryEventStore] with the default maximum size will be used. EventStore EventStore +} + +// A StreamableServerTransport implements the server side of the MCP streamable +// transport. +// +// Each StreamableServerTransport may be connected (via [Server.Connect]) at +// most once, since [StreamableServerTransport.ServeHTTP] serves messages to +// the connected session. +type StreamableServerTransport struct { + // SessionID is the ID of this session. + // + // If SessionID is the empty string, this is a 'stateless' session, which has + // limited ability to communicate with the client. Otherwise, the session ID + // must be globally unique, that is, different from any other session ID + // anywhere, past and future. (We recommend using a crypto random number + // generator to produce one, as with [crypto/rand.Text].) + SessionID string + + // Storage for events, to enable stream resumption. + // If nil, a [MemoryEventStore] with the default maximum size will be used. + EventStore EventStore // jsonResponse, if set, tells the server to prefer to respond to requests // using application/json responses rather than text/event-stream. @@ -212,22 +234,36 @@ type StreamableServerTransportOptions struct { // requests made within the context of a server request will be sent to the // hanging GET request, if any. jsonResponse bool + + // connection is non-nil if and only if the transport has been connected. + connection *streamableServerConn } // NewStreamableServerTransport returns a new [StreamableServerTransport] with // the given session ID and options. -// The session ID must be globally unique, that is, different from any other -// session ID anywhere, past and future. (We recommend using a crypto random number -// generator to produce one, as with [crypto/rand.Text].) // -// A StreamableServerTransport implements the server-side of the streamable -// transport. +// Deprecated: use a StreamableServerTransport literal. +// +//go:fix inline. func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransportOptions) *StreamableServerTransport { - if opts == nil { - opts = &StreamableServerTransportOptions{} - } t := &StreamableServerTransport{ - sessionID: sessionID, + SessionID: sessionID, + } + if opts != nil { + t.EventStore = opts.EventStore + } + return t +} + +// Connect implements the [Transport] interface. +func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) { + if t.connection != nil { + return nil, fmt.Errorf("transport already connected") + } + t.connection = &streamableServerConn{ + sessionID: t.SessionID, + eventStore: t.EventStore, + jsonResponse: t.jsonResponse, incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), streams: make(map[StreamID]*stream), @@ -237,29 +273,23 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp // // It is always text/event-stream, since it must carry arbitrarily many // messages. - t.streams[0] = newStream(0, false) - if opts != nil { - t.opts = *opts + t.connection.streams[0] = newStream(0, false) + if t.connection.eventStore == nil { + t.connection.eventStore = NewMemoryEventStore(nil) } - if t.opts.EventStore == nil { - t.opts.EventStore = NewMemoryEventStore(nil) - } - return t + return t.connection, nil } -func (t *StreamableServerTransport) SessionID() string { - return t.sessionID -} +type streamableServerConn struct { + sessionID string + jsonResponse bool + eventStore EventStore -// A StreamableServerTransport implements the [Transport] interface for a -// single session. -type StreamableServerTransport struct { lastStreamID atomic.Int64 // last stream ID used, atomically incremented - sessionID string - opts StreamableServerTransportOptions - incoming chan jsonrpc.Message // messages from the client to the server - done chan struct{} + opts StreamableServerTransportOptions + incoming chan jsonrpc.Message // messages from the client to the server + done chan struct{} mu sync.Mutex // Sessions are closed exactly once. @@ -289,6 +319,10 @@ type StreamableServerTransport struct { requestStreams map[jsonrpc.ID]StreamID } +func (c *streamableServerConn) SessionID() string { + return c.sessionID +} + // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header. @@ -349,13 +383,6 @@ func signalChanPtr() *chan struct{} { // [ServerSession]. type StreamID int64 -// Connect implements the [Transport] interface. -// -// TODO(rfindley): Connect should return a new object. (Why?) -func (s *StreamableServerTransport) Connect(context.Context) (Connection, error) { - return s, nil -} - // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in // the course of handling incoming requests are correlated with the incoming @@ -381,15 +408,20 @@ type idContextKey struct{} // ServeHTTP handles a single HTTP request for the session. func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.connection == nil { + http.Error(w, "transport not connected", http.StatusInternalServerError) + return + } switch req.Method { case http.MethodGet: - t.serveGET(w, req) + t.connection.serveGET(w, req) case http.MethodPost: - t.servePOST(w, req) + t.connection.servePOST(w, req) default: // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. w.Header().Set("Allow", "GET, POST") http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + return } } @@ -397,7 +429,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // message parsed from the Last-Event-ID header. // // It returns an HTTP status code and error message. -func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { +func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. id := StreamID(0) // By default, we haven't seen a last index. Since indices start at 0, we represent @@ -414,9 +446,9 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re } } - t.mu.Lock() - stream, ok := t.streams[id] - t.mu.Unlock() + c.mu.Lock() + stream, ok := c.streams[id] + c.mu.Unlock() if !ok { http.Error(w, "unknown stream", http.StatusBadRequest) return @@ -428,7 +460,7 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re } defer stream.signal.Store(nil) persistent := id == 0 // Only the special stream 0 is a hanging get. - t.respondSSE(stream, w, req, lastIdx, persistent) + c.respondSSE(stream, w, req, lastIdx, persistent) } // servePOST handles an incoming message, and replies with either an outgoing @@ -436,7 +468,7 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re // jsonResponse option is set. // // It returns an HTTP status code and error message. -func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { +func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) { if len(req.Header.Values("Last-Event-ID")) > 0 { http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) return @@ -485,20 +517,20 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream = newStream(StreamID(t.lastStreamID.Add(1)), t.opts.jsonResponse) - t.mu.Lock() - t.streams[stream.id] = stream + stream = newStream(StreamID(c.lastStreamID.Add(1)), c.jsonResponse) + c.mu.Lock() + c.streams[stream.id] = stream stream.requests = requests for reqID := range requests { - t.requestStreams[reqID] = stream.id + c.requestStreams[reqID] = stream.id } - t.mu.Unlock() + c.mu.Unlock() stream.signal.Store(signalChanPtr()) } // Publish incoming messages. for _, msg := range incoming { - t.incoming <- msg + c.incoming <- msg } if stream == nil { @@ -507,22 +539,22 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R } if stream.jsonResponse { - t.respondJSON(stream, w, req) + c.respondJSON(stream, w, req) } else { - t.respondSSE(stream, w, req, -1, false) + c.respondSSE(stream, w, req, -1, false) } } -func (t *StreamableServerTransport) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { +func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "application/json") - if t.sessionID != "" { - w.Header().Set(sessionIDHeader, t.sessionID) + if c.sessionID != "" { + w.Header().Set(sessionIDHeader, c.sessionID) } var msgs []json.RawMessage ctx := req.Context() - for msg, ok := range t.messages(ctx, stream, false) { + for msg, ok := range c.messages(ctx, stream, false) { if !ok { if ctx.Err() != nil { w.WriteHeader(http.StatusNoContent) @@ -550,15 +582,15 @@ func (t *StreamableServerTransport) respondJSON(stream *stream, w http.ResponseW } // lastIndex is the index of the last seen event if resuming, else -1. -func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { +func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { writes := 0 // Accept checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") - if t.sessionID != "" { - w.Header().Set(sessionIDHeader, t.sessionID) + if c.sessionID != "" { + w.Header().Set(sessionIDHeader, c.sessionID) } // write one event containing data. @@ -588,7 +620,7 @@ func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWr if lastIndex >= 0 { // Resume. - for data, err := range t.opts.EventStore.After(req.Context(), t.SessionID(), stream.id, lastIndex) { + for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) { if err != nil { // TODO: reevaluate these status codes. // Maybe distinguish between storage errors, which are 500s, and missing @@ -610,7 +642,7 @@ func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWr // Repeatedly collect pending outgoing events and send them. ctx := req.Context() - for msg, ok := range t.messages(ctx, stream, persistent) { + for msg, ok := range c.messages(ctx, stream, persistent) { if !ok { if ctx.Err() != nil && writes == 0 { // This probably doesn't matter, but respond with NoContent if the client disconnected. @@ -620,7 +652,7 @@ func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWr } return } - if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, msg); err != nil { + if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil { errorf(http.StatusInternalServerError, "storing event: %v", err.Error()) return } @@ -638,14 +670,14 @@ func (t *StreamableServerTransport) respondSSE(stream *stream, w http.ResponseWr // If the stream did not terminate normally, it is either because ctx was // cancelled, or the connection is closed: check the ctx.Err() to differentiate // these cases. -func (t *StreamableServerTransport) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] { +func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] { return func(yield func(json.RawMessage, bool) bool) { for { - t.mu.Lock() + c.mu.Lock() outgoing := stream.outgoing stream.outgoing = nil nOutstanding := len(stream.requests) - t.mu.Unlock() + c.mu.Unlock() for _, data := range outgoing { if !yield(data, true) { @@ -665,7 +697,7 @@ func (t *StreamableServerTransport) messages(ctx context.Context, stream *stream select { case <-*stream.signal.Load(): // there are new outgoing messages // return to top of loop - case <-t.done: // session is closed + case <-c.done: // session is closed yield(nil, false) return case <-ctx.Done(): @@ -708,22 +740,22 @@ func parseEventID(eventID string) (sid StreamID, idx int, ok bool) { } // Read implements the [Connection] interface. -func (t *StreamableServerTransport) Read(ctx context.Context) (jsonrpc.Message, error) { +func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() - case msg, ok := <-t.incoming: + case msg, ok := <-c.incoming: if !ok { return nil, io.EOF } return msg, nil - case <-t.done: + case <-c.done: return nil, io.EOF } } // Write implements the [Connection] interface. -func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Message) error { +func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { // Find the incoming request that this write relates to, if any. var forRequest jsonrpc.ID isResponse := false @@ -746,9 +778,9 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa // connection 0. var forStream StreamID if forRequest.IsValid() { - t.mu.Lock() - forStream = t.requestStreams[forRequest] - t.mu.Unlock() + c.mu.Lock() + forStream = c.requestStreams[forRequest] + c.mu.Unlock() } data, err := jsonrpc2.EncodeMessage(msg) @@ -756,13 +788,13 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa return err } - t.mu.Lock() - defer t.mu.Unlock() - if t.isDone { + c.mu.Lock() + defer c.mu.Unlock() + if c.isDone { return errors.New("session is closed") } - stream := t.streams[forStream] + stream := c.streams[forStream] if stream == nil { return fmt.Errorf("no stream with ID %d", forStream) } @@ -775,7 +807,7 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa // TODO(rfindley): either of these, particularly the first, might be // considered a bug in the server. Report it through a side-channel? if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse { - stream = t.streams[0] + stream = c.streams[0] } // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0 @@ -798,15 +830,15 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa } // Close implements the [Connection] interface. -func (t *StreamableServerTransport) Close() error { - t.mu.Lock() - defer t.mu.Unlock() - if !t.isDone { - t.isDone = true - close(t.done) +func (c *streamableServerConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.isDone { + c.isDone = true + close(c.done) // TODO: find a way to plumb a context here, or an event store with a long-running // close operation can take arbitrary time. Alternative: impose a fixed timeout here. - return t.opts.EventStore.SessionClosed(context.TODO(), t.sessionID) + return c.eventStore.SessionClosed(context.TODO(), c.sessionID) } return nil } @@ -815,8 +847,9 @@ func (t *StreamableServerTransport) Close() error { // endpoint serving the streamable HTTP transport defined by the 2025-03-26 // version of the spec. type StreamableClientTransport struct { - url string - opts StreamableClientTransportOptions + Endpoint string + HTTPClient *http.Client + ReconnectOptions *StreamableReconnectOptions } // StreamableReconnectOptions defines parameters for client reconnect attempts. @@ -847,6 +880,8 @@ var DefaultReconnectOptions = &StreamableReconnectOptions{ // StreamableClientTransportOptions provides options for the // [NewStreamableClientTransport] constructor. +// +// Deprecated: use a StremableClientTransport literal. type StreamableClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. @@ -856,10 +891,15 @@ type StreamableClientTransportOptions struct { // NewStreamableClientTransport returns a new client transport that connects to // the streamable HTTP server at the provided URL. +// +// Deprecated: use a StreamableClientTransport literal. +// +//go:fix inline func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport { - t := &StreamableClientTransport{url: url} + t := &StreamableClientTransport{Endpoint: url} if opts != nil { - t.opts = *opts + t.HTTPClient = opts.HTTPClient + t.ReconnectOptions = opts.ReconnectOptions } return t } @@ -873,11 +913,11 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt // When closed, the connection issues a DELETE request to terminate the logical // session. func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) { - client := t.opts.HTTPClient + client := t.HTTPClient if client == nil { client = http.DefaultClient } - reconnOpts := t.opts.ReconnectOptions + reconnOpts := t.ReconnectOptions if reconnOpts == nil { reconnOpts = DefaultReconnectOptions } @@ -886,7 +926,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // cancelling its blocking network operations, which prevents hangs on exit. connCtx, cancel := context.WithCancel(context.Background()) conn := &streamableClientConn{ - url: t.url, + url: t.Endpoint, client: client, incoming: make(chan []byte, 100), done: make(chan struct{}), diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 54803939..ba7af9fa 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -44,7 +44,7 @@ func TestStreamableTransports(t *testing.T) { // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ - transportOptions: &StreamableServerTransportOptions{jsonResponse: useJSON}, + jsonResponse: useJSON, }) var ( @@ -77,9 +77,10 @@ func TestStreamableTransports(t *testing.T) { } jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}}) httpClient := &http.Client{Jar: jar} - transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, HTTPClient: httpClient, - }) + } client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) if err != nil { @@ -173,7 +174,7 @@ func TestClientReplay(t *testing.T) { notifications <- params.Message }, }) - clientSession, err := client.Connect(ctx, NewStreamableClientTransport(proxy.URL, nil), nil) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: proxy.URL}, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -239,7 +240,7 @@ func TestServerInitiatedSSE(t *testing.T) { notifications <- "toolListChanged" }, }) - clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil), nil) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } @@ -765,7 +766,7 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) defer httpServer.Close() - transport := NewStreamableClientTransport(httpServer.URL, nil) + transport := &StreamableClientTransport{Endpoint: httpServer.URL} client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) if err != nil { diff --git a/mcp/transport.go b/mcp/transport.go index f45d7d2b..76b79986 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -87,36 +87,39 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. type StdioTransport struct { - ioTransport } -// An ioTransport is a [Transport] that communicates using newline-delimited -// JSON over an io.ReadWriteCloser. -type ioTransport struct { - rwc io.ReadWriteCloser -} - -func (t *ioTransport) Connect(context.Context) (Connection, error) { - return newIOConn(t.rwc), nil +// Connect implements the [Transport] interface. +func (*StdioTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{os.Stdin, os.Stdout}), nil } // NewStdioTransport constructs a transport that communicates over // stdin/stdout. +// +// Deprecated: use a StdioTransport literal. +// +//go:fix inline func NewStdioTransport() *StdioTransport { - return &StdioTransport{ioTransport{rwc{os.Stdin, os.Stdout}}} + return &StdioTransport{} } // An InMemoryTransport is a [Transport] that communicates over an in-memory // network connection, using newline-delimited JSON. type InMemoryTransport struct { - ioTransport + rwc io.ReadWriteCloser } -// NewInMemoryTransports returns two InMemoryTransports that connect to each +// Connect implements the [Transport] interface. +func (t *InMemoryTransport) Connect(context.Context) (Connection, error) { + return newIOConn(t.rwc), nil +} + +// NewInMemoryTransports returns two [InMemoryTransports] that connect to each // other. func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { c1, c2 := net.Pipe() - return &InMemoryTransport{ioTransport{c1}}, &InMemoryTransport{ioTransport{c2}} + return &InMemoryTransport{c1}, &InMemoryTransport{c2} } type binder[T handler, State any] interface { @@ -208,24 +211,28 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params // A LoggingTransport is a [Transport] that delegates to another transport, // writing RPC logs to an io.Writer. type LoggingTransport struct { - delegate Transport - w io.Writer + Transport Transport + Writer io.Writer } // NewLoggingTransport creates a new LoggingTransport that delegates to the // provided transport, writing RPC logs to the provided io.Writer. +// +// Deprecated: use a LoggingTransport literal. +// +//go:fix inline func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport { - return &LoggingTransport{delegate, w} + return &LoggingTransport{Transport: delegate, Writer: w} } // Connect connects the underlying transport, returning a [Connection] that writes // logs to the configured destination. func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { - delegate, err := t.delegate.Connect(ctx) + delegate, err := t.Transport.Connect(ctx) if err != nil { return nil, err } - return &loggingConn{delegate, t.w}, nil + return &loggingConn{delegate, t.Writer}, nil } type loggingConn struct { From 6e03217c831b1d142ed8a29c5d5916f47884008b Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 12 Aug 2025 17:49:50 -0400 Subject: [PATCH 093/221] mcp: introduce Requests (#267) Handlers and client/server methods take a single request argument combining the session and params. This slightly simplifies the signature for handlers, since there is no longer an often-ignored session argument. But more importantly, it opens the door to adding more information in requests, such as auth info and HTTP request headers. For #243. --- README.md | 4 +- design/design.md | 53 +++-- examples/server/completion/main.go | 6 +- examples/server/custom-transport/main.go | 4 +- examples/server/hello/main.go | 10 +- examples/server/memory/kb.go | 36 ++-- examples/server/memory/kb_test.go | 30 +-- examples/server/sequentialthinking/main.go | 22 +- .../server/sequentialthinking/main_test.go | 30 +-- examples/server/sse/main.go | 4 +- internal/readme/server/server.go | 4 +- mcp/client.go | 127 +++++++----- mcp/example_middleware_test.go | 20 +- mcp/example_progress_test.go | 8 +- mcp/mcp_test.go | 62 +++--- mcp/resource.go | 2 +- mcp/server.go | 194 ++++++++++-------- mcp/server_example_test.go | 4 +- mcp/server_test.go | 12 +- mcp/shared.go | 185 ++++++++++++----- mcp/shared_test.go | 19 +- mcp/sse_example_test.go | 4 +- mcp/streamable_test.go | 32 +-- mcp/tool.go | 23 ++- mcp/tool_test.go | 2 +- 25 files changed, 517 insertions(+), 380 deletions(-) diff --git a/README.md b/README.md index 22a3fed3..4700d087 100644 --- a/README.md +++ b/README.md @@ -115,9 +115,9 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiParams]]) (*mcp.CallToolResultFor[any], error) { return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + params.Arguments.Name}}, + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}}, }, nil } diff --git a/design/design.md b/design/design.md index 65a2b61f..fa2270e2 100644 --- a/design/design.md +++ b/design/design.md @@ -408,15 +408,12 @@ We provide a mechanism to add MCP-level middleware on the both the client and se ```go // A MethodHandler handles MCP messages. -// The params argument is an XXXParams struct pointer, such as *GetPromptParams. -// For methods, a MethodHandler must return either an XXResult struct pointer and a nil error, or -// nil with a non-nil error. -// For notifications, a MethodHandler must return nil, nil. -type MethodHandler[S Session] func( - ctx context.Context, _ *S, method string, params Params) (result Result, err error) +// For methods, exactly one of the return values must be nil. +// For notifications, both must be nil. +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) // Middleware is a function from MethodHandlers to MethodHandlers. -type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] +type Middleware func(MethodHandler) MethodHandler // AddMiddleware wraps the client/server's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one @@ -424,17 +421,17 @@ type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] // // For example, AddMiddleware(m1, m2, m3) augments the server method handler as // m1(m2(m3(handler))). -func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) -func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]) -func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) -func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]) +func (c *Client) AddSendingMiddleware(middleware ...Middleware) +func (c *Client) AddReceivingMiddleware(middleware ...Middleware) +func (s *Server) AddSendingMiddleware(middleware ...Middleware) +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) ``` As an example, this code adds server-side logging: ```go -func withLogging(h mcp.MethodHandler[*mcp.ServerSession]) mcp.MethodHandler[*mcp.ServerSession]{ - return func(ctx context.Context, s *mcp.ServerSession, method string, params mcp.Params) (res mcp.Result, err error) { +func withLogging(h mcp.MethodHandler) mcp.MethodHandler{ + return func(ctx context.Context, method string, req mcp.Request) (res mcp.Result, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() return h(ctx, s , method, params) @@ -597,8 +594,9 @@ type Tool struct { Name string `json:"name"` } -type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) -type ToolHandler = ToolHandlerFor[map[string]any, any] +// A ToolHandlerFor handles a call to tools/call with typed arguments and results. +type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) + ``` Add tools to a server with the `AddTool` method or function. The function is generic and infers schemas from the handler @@ -648,8 +646,8 @@ type AddParams struct { Y int `json:"y"` } -func addHandler(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddParams]) (*mcp.CallToolResultFor[int], error) { - return &mcp.CallToolResultFor[int]{StructuredContent: params.Arguments.X + params.Arguments.Y}, nil +func addHandler(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddParams]]) (*mcp.CallToolResultFor[int], error) { + return &mcp.CallToolResultFor[int]{StructuredContent: req.Params.Arguments.X + req.Params.Arguments.Y}, nil } ``` @@ -665,8 +663,6 @@ Client sessions can call the spec method `ListTools` or an iterator method `Tool ```go func (cs *ClientSession) CallTool(context.Context, *CallToolParams[json.RawMessage]) (*CallToolResult, error) - -func CallTool[TArgs any](context.Context, *ClientSession, *CallToolParams[TArgs]) (*CallToolResult, error) ``` **Differences from mcp-go**: We provide a full JSON Schema implementation for validating tool input schemas against incoming arguments. The `jsonschema.Schema` type provides exported features for all keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct any schema they want. The `jsonschema.For[T]` function can infer a schema from a Go struct. These combined features eliminate the need for variadic arguments to construct tool schemas. @@ -752,10 +748,10 @@ If a server author wants to support resource subscriptions, they must provide ha ```go type ServerOptions struct { ... - // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, ss *ServerSession, *SubscribeParams) error - // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, ss *ServerSession, *UnsubscribeParams) error + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *ServerRequest[*SubscribeParams]) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *ServerRequest[*UnsubscribeParams]) error } ``` @@ -774,10 +770,10 @@ When a list of tools, prompts or resources changes as the result of an AddXXX or ```go type ClientOptions struct { ... - ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) - PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) + ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams]) + PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams]) // For both resources and resource templates. - ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) + ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams]) } ``` @@ -788,13 +784,10 @@ type ClientOptions struct { Clients call the spec method `Complete` to request completions. If a server installs a `CompletionHandler`, it will be called when the client sends a completion request. ```go -// A CompletionHandler handles a call to completion/complete. -type CompletionHandler func(context.Context, *ServerSession, *CompleteParams) (*CompleteResult, error) - type ServerOptions struct { ... // If non-nil, called when a client sends a completion request. - CompletionHandler CompletionHandler + CompletionHandler func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) } ``` diff --git a/examples/server/completion/main.go b/examples/server/completion/main.go index d6530b5d..f05b2721 100644 --- a/examples/server/completion/main.go +++ b/examples/server/completion/main.go @@ -16,17 +16,17 @@ import ( // a CompletionHandler to an MCP Server's options. func main() { // Define your custom CompletionHandler logic. - myCompletionHandler := func(_ context.Context, _ *mcp.ServerSession, params *mcp.CompleteParams) (*mcp.CompleteResult, error) { + myCompletionHandler := func(_ context.Context, req *mcp.ServerRequest[*mcp.CompleteParams]) (*mcp.CompleteResult, error) { // In a real application, you'd implement actual completion logic here. // For this example, we return a fixed set of suggestions. var suggestions []string - switch params.Ref.Type { + switch req.Params.Ref.Type { case "ref/prompt": suggestions = []string{"suggestion1", "suggestion2", "suggestion3"} case "ref/resource": suggestions = []string{"suggestion4", "suggestion5", "suggestion6"} default: - return nil, fmt.Errorf("unrecognized content type %s", params.Ref.Type) + return nil, fmt.Errorf("unrecognized content type %s", req.Params.Ref.Type) } return &mcp.CompleteResult{ diff --git a/examples/server/custom-transport/main.go b/examples/server/custom-transport/main.go index cc4f15f3..bf0306cf 100644 --- a/examples/server/custom-transport/main.go +++ b/examples/server/custom-transport/main.go @@ -85,10 +85,10 @@ type HiArgs struct { } // SayHi is a tool handler that responds with a greeting. -func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiArgs]]) (*mcp.CallToolResultFor[struct{}], error) { return &mcp.CallToolResultFor[struct{}]{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, }, }, nil } diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index 00fb37a6..8125441b 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -22,10 +22,10 @@ type HiArgs struct { Name string `json:"name" jsonschema:"the name to say hi to"` } -func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiArgs]]) (*mcp.CallToolResultFor[struct{}], error) { return &mcp.CallToolResultFor[struct{}]{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, }, }, nil } @@ -69,8 +69,8 @@ var embeddedResources = map[string]string{ "info": "This is the hello example server.", } -func handleEmbeddedResource(_ context.Context, _ *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { - u, err := url.Parse(params.URI) +func handleEmbeddedResource(_ context.Context, req *mcp.ServerRequest[*mcp.ReadResourceParams]) (*mcp.ReadResourceResult, error) { + u, err := url.Parse(req.Params.URI) if err != nil { return nil, err } @@ -84,7 +84,7 @@ func handleEmbeddedResource(_ context.Context, _ *mcp.ServerSession, params *mcp } return &mcp.ReadResourceResult{ Contents: []*mcp.ResourceContents{ - {URI: params.URI, MIMEType: "text/plain", Text: text}, + {URI: req.Params.URI, MIMEType: "text/plain", Text: text}, }, }, nil } diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index a274f057..f053bee5 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -84,7 +84,7 @@ func (fs *fileStore) Read() ([]byte, error) { // Write saves data to file with 0600 permissions. func (fs *fileStore) Write(data []byte) error { - if err := os.WriteFile(fs.path, data, 0600); err != nil { + if err := os.WriteFile(fs.path, data, 0o600); err != nil { return fmt.Errorf("failed to write file %s: %w", fs.path, err) } return nil @@ -431,10 +431,10 @@ func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { }, nil } -func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateEntitiesArgs]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { +func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[CreateEntitiesArgs]]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { var res mcp.CallToolResultFor[CreateEntitiesResult] - entities, err := k.createEntities(params.Arguments.Entities) + entities, err := k.createEntities(req.Params.Arguments.Entities) if err != nil { return nil, err } @@ -450,10 +450,10 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession return &res, nil } -func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateRelationsArgs]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { +func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[CreateRelationsArgs]]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { var res mcp.CallToolResultFor[CreateRelationsResult] - relations, err := k.createRelations(params.Arguments.Relations) + relations, err := k.createRelations(req.Params.Arguments.Relations) if err != nil { return nil, err } @@ -469,10 +469,10 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSessio return &res, nil } -func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddObservationsArgs]) (*mcp.CallToolResultFor[AddObservationsResult], error) { +func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddObservationsArgs]]) (*mcp.CallToolResultFor[AddObservationsResult], error) { var res mcp.CallToolResultFor[AddObservationsResult] - observations, err := k.addObservations(params.Arguments.Observations) + observations, err := k.addObservations(req.Params.Arguments.Observations) if err != nil { return nil, err } @@ -488,10 +488,10 @@ func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSessio return &res, nil } -func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteEntitiesArgs]) (*mcp.CallToolResultFor[struct{}], error) { +func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteEntitiesArgs]]) (*mcp.CallToolResultFor[struct{}], error) { var res mcp.CallToolResultFor[struct{}] - err := k.deleteEntities(params.Arguments.EntityNames) + err := k.deleteEntities(req.Params.Arguments.EntityNames) if err != nil { return nil, err } @@ -503,10 +503,10 @@ func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession return &res, nil } -func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteObservationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { +func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteObservationsArgs]]) (*mcp.CallToolResultFor[struct{}], error) { var res mcp.CallToolResultFor[struct{}] - err := k.deleteObservations(params.Arguments.Deletions) + err := k.deleteObservations(req.Params.Arguments.Deletions) if err != nil { return nil, err } @@ -518,10 +518,10 @@ func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSes return &res, nil } -func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteRelationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { +func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteRelationsArgs]]) (*mcp.CallToolResultFor[struct{}], error) { var res mcp.CallToolResultFor[struct{}] - err := k.deleteRelations(params.Arguments.Relations) + err := k.deleteRelations(req.Params.Arguments.Relations) if err != nil { return nil, err } @@ -533,7 +533,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSessio return &res, nil } -func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[struct{}]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { +func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[struct{}]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { var res mcp.CallToolResultFor[KnowledgeGraph] graph, err := k.loadGraph() @@ -549,10 +549,10 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, par return &res, nil } -func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[SearchNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { +func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SearchNodesArgs]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { var res mcp.CallToolResultFor[KnowledgeGraph] - graph, err := k.searchNodes(params.Arguments.Query) + graph, err := k.searchNodes(req.Params.Arguments.Query) if err != nil { return nil, err } @@ -565,10 +565,10 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, p return &res, nil } -func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[OpenNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { +func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[OpenNodesArgs]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { var res mcp.CallToolResultFor[KnowledgeGraph] - graph, err := k.openNodes(params.Arguments.Names) + graph, err := k.openNodes(req.Params.Arguments.Names) if err != nil { return nil, err } diff --git a/examples/server/memory/kb_test.go b/examples/server/memory/kb_test.go index e4fbacc9..6e29d5e4 100644 --- a/examples/server/memory/kb_test.go +++ b/examples/server/memory/kb_test.go @@ -230,7 +230,7 @@ func TestSaveAndLoadGraph(t *testing.T) { // Test malformed data handling if fs, ok := s.(*fileStore); ok { - err := os.WriteFile(fs.path, []byte("invalid json"), 0600) + err := os.WriteFile(fs.path, []byte("invalid json"), 0o600) if err != nil { t.Fatalf("failed to write invalid json: %v", err) } @@ -450,7 +450,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - createResult, err := kb.CreateEntities(ctx, serverSession, createEntitiesParams) + createResult, err := kb.CreateEntities(ctx, requestFor(serverSession, createEntitiesParams)) if err != nil { t.Fatalf("MCP CreateEntities failed: %v", err) } @@ -463,7 +463,7 @@ func TestMCPServerIntegration(t *testing.T) { // Test ReadGraph through MCP readParams := &mcp.CallToolParamsFor[struct{}]{} - readResult, err := kb.ReadGraph(ctx, serverSession, readParams) + readResult, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) if err != nil { t.Fatalf("MCP ReadGraph failed: %v", err) } @@ -487,7 +487,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - relationsResult, err := kb.CreateRelations(ctx, serverSession, createRelationsParams) + relationsResult, err := kb.CreateRelations(ctx, requestFor(serverSession, createRelationsParams)) if err != nil { t.Fatalf("MCP CreateRelations failed: %v", err) } @@ -510,7 +510,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams) + obsResult, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) if err != nil { t.Fatalf("MCP AddObservations failed: %v", err) } @@ -528,7 +528,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - searchResult, err := kb.SearchNodes(ctx, serverSession, searchParams) + searchResult, err := kb.SearchNodes(ctx, requestFor(serverSession, searchParams)) if err != nil { t.Fatalf("MCP SearchNodes failed: %v", err) } @@ -546,7 +546,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - openResult, err := kb.OpenNodes(ctx, serverSession, openParams) + openResult, err := kb.OpenNodes(ctx, requestFor(serverSession, openParams)) if err != nil { t.Fatalf("MCP OpenNodes failed: %v", err) } @@ -569,7 +569,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - deleteObsResult, err := kb.DeleteObservations(ctx, serverSession, deleteObsParams) + deleteObsResult, err := kb.DeleteObservations(ctx, requestFor(serverSession, deleteObsParams)) if err != nil { t.Fatalf("MCP DeleteObservations failed: %v", err) } @@ -590,7 +590,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - deleteRelResult, err := kb.DeleteRelations(ctx, serverSession, deleteRelParams) + deleteRelResult, err := kb.DeleteRelations(ctx, requestFor(serverSession, deleteRelParams)) if err != nil { t.Fatalf("MCP DeleteRelations failed: %v", err) } @@ -605,7 +605,7 @@ func TestMCPServerIntegration(t *testing.T) { }, } - deleteEntResult, err := kb.DeleteEntities(ctx, serverSession, deleteEntParams) + deleteEntResult, err := kb.DeleteEntities(ctx, requestFor(serverSession, deleteEntParams)) if err != nil { t.Fatalf("MCP DeleteEntities failed: %v", err) } @@ -614,7 +614,7 @@ func TestMCPServerIntegration(t *testing.T) { } // Verify final state - finalRead, err := kb.ReadGraph(ctx, serverSession, readParams) + finalRead, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) if err != nil { t.Fatalf("Final MCP ReadGraph failed: %v", err) } @@ -647,7 +647,7 @@ func TestMCPErrorHandling(t *testing.T) { }, } - _, err := kb.AddObservations(ctx, serverSession, addObsParams) + _, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) if err == nil { t.Errorf("expected MCP AddObservations to return error for non-existent entity") } else { @@ -678,7 +678,7 @@ func TestMCPResponseFormat(t *testing.T) { }, } - result, err := kb.CreateEntities(ctx, serverSession, createParams) + result, err := kb.CreateEntities(ctx, requestFor(serverSession, createParams)) if err != nil { t.Fatalf("CreateEntities failed: %v", err) } @@ -701,3 +701,7 @@ func TestMCPResponseFormat(t *testing.T) { t.Errorf("expected Content[0] to be TextContent") } } + +func requestFor[P mcp.Params](ss *mcp.ServerSession, p P) *mcp.ServerRequest[P] { + return &mcp.ServerRequest[P]{Session: ss, Params: p} +} diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go index e9cb594d..45a4fa6f 100644 --- a/examples/server/sequentialthinking/main.go +++ b/examples/server/sequentialthinking/main.go @@ -231,8 +231,8 @@ func deepCopyThoughts(thoughts []*Thought) []*Thought { } // StartThinking begins a new sequential thinking session for a complex problem. -func StartThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[StartThinkingArgs]) (*mcp.CallToolResultFor[any], error) { - args := params.Arguments +func StartThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[StartThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { + args := req.Params.Arguments sessionID := args.SessionID if sessionID == "" { @@ -266,8 +266,8 @@ func StartThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallT } // ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. -func ContinueThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[ContinueThinkingArgs]) (*mcp.CallToolResultFor[any], error) { - args := params.Arguments +func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[ContinueThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { + args := req.Params.Arguments // Handle revision of existing thought if args.ReviseStep != nil { @@ -395,8 +395,8 @@ func ContinueThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.Ca } // ReviewThinking provides a complete review of the thinking process for a session. -func ReviewThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[ReviewThinkingArgs]) (*mcp.CallToolResultFor[any], error) { - args := params.Arguments +func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[ReviewThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { + args := req.Params.Arguments // Get a snapshot of the session to avoid race conditions sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) @@ -434,11 +434,11 @@ func ReviewThinking(ctx context.Context, ss *mcp.ServerSession, params *mcp.Call } // ThinkingHistory handles resource requests for thinking session data and history. -func ThinkingHistory(ctx context.Context, ss *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { +func ThinkingHistory(ctx context.Context, req *mcp.ServerRequest[*mcp.ReadResourceParams]) (*mcp.ReadResourceResult, error) { // Extract session ID from URI (e.g., "thinking://session_123") - u, err := url.Parse(params.URI) + u, err := url.Parse(req.Params.URI) if err != nil { - return nil, fmt.Errorf("invalid thinking resource URI: %s", params.URI) + return nil, fmt.Errorf("invalid thinking resource URI: %s", req.Params.URI) } if u.Scheme != "thinking" { return nil, fmt.Errorf("invalid thinking resource URI scheme: %s", u.Scheme) @@ -456,7 +456,7 @@ func ThinkingHistory(ctx context.Context, ss *mcp.ServerSession, params *mcp.Rea return &mcp.ReadResourceResult{ Contents: []*mcp.ResourceContents{ { - URI: params.URI, + URI: req.Params.URI, MIMEType: "application/json", Text: string(data), }, @@ -478,7 +478,7 @@ func ThinkingHistory(ctx context.Context, ss *mcp.ServerSession, params *mcp.Rea return &mcp.ReadResourceResult{ Contents: []*mcp.ResourceContents{ { - URI: params.URI, + URI: req.Params.URI, MIMEType: "application/json", Text: string(data), }, diff --git a/examples/server/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go index 13cd9ccb..c5e4a95a 100644 --- a/examples/server/sequentialthinking/main_test.go +++ b/examples/server/sequentialthinking/main_test.go @@ -31,7 +31,7 @@ func TestStartThinking(t *testing.T) { Arguments: args, } - result, err := StartThinking(ctx, nil, params) + result, err := StartThinking(ctx, requestFor(params)) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -89,7 +89,7 @@ func TestContinueThinking(t *testing.T) { Arguments: startArgs, } - _, err := StartThinking(ctx, nil, startParams) + _, err := StartThinking(ctx, requestFor(startParams)) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -105,7 +105,7 @@ func TestContinueThinking(t *testing.T) { Arguments: continueArgs, } - result, err := ContinueThinking(ctx, nil, continueParams) + result, err := ContinueThinking(ctx, requestFor(continueParams)) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -158,7 +158,7 @@ func TestContinueThinkingWithCompletion(t *testing.T) { Arguments: startArgs, } - _, err := StartThinking(ctx, nil, startParams) + _, err := StartThinking(ctx, requestFor(startParams)) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -176,7 +176,7 @@ func TestContinueThinkingWithCompletion(t *testing.T) { Arguments: continueArgs, } - result, err := ContinueThinking(ctx, nil, continueParams) + result, err := ContinueThinking(ctx, requestFor(continueParams)) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -233,7 +233,7 @@ func TestContinueThinkingRevision(t *testing.T) { Arguments: continueArgs, } - result, err := ContinueThinking(ctx, nil, continueParams) + result, err := ContinueThinking(ctx, requestFor(continueParams)) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -289,7 +289,7 @@ func TestContinueThinkingBranching(t *testing.T) { Arguments: continueArgs, } - result, err := ContinueThinking(ctx, nil, continueParams) + result, err := ContinueThinking(ctx, requestFor(continueParams)) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -356,7 +356,7 @@ func TestReviewThinking(t *testing.T) { Arguments: reviewArgs, } - result, err := ReviewThinking(ctx, nil, reviewParams) + result, err := ReviewThinking(ctx, requestFor(reviewParams)) if err != nil { t.Fatalf("ReviewThinking() error = %v", err) } @@ -431,7 +431,7 @@ func TestThinkingHistory(t *testing.T) { URI: "thinking://sessions", } - result, err := ThinkingHistory(ctx, nil, listParams) + result, err := ThinkingHistory(ctx, requestFor(listParams)) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -461,7 +461,7 @@ func TestThinkingHistory(t *testing.T) { URI: "thinking://session1", } - result, err = ThinkingHistory(ctx, nil, sessionParams) + result, err = ThinkingHistory(ctx, requestFor(sessionParams)) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -496,7 +496,7 @@ func TestInvalidOperations(t *testing.T) { Arguments: continueArgs, } - _, err := ContinueThinking(ctx, nil, continueParams) + _, err := ContinueThinking(ctx, requestFor(continueParams)) if err == nil { t.Error("Expected error for non-existent session") } @@ -511,7 +511,7 @@ func TestInvalidOperations(t *testing.T) { Arguments: reviewArgs, } - _, err = ReviewThinking(ctx, nil, reviewParams) + _, err = ReviewThinking(ctx, requestFor(reviewParams)) if err == nil { t.Error("Expected error for non-existent session in review") } @@ -541,8 +541,12 @@ func TestInvalidOperations(t *testing.T) { Arguments: invalidReviseArgs, } - _, err = ContinueThinking(ctx, nil, invalidReviseParams) + _, err = ContinueThinking(ctx, requestFor(invalidReviseParams)) if err == nil { t.Error("Expected error for invalid revision step") } } + +func requestFor[P mcp.Params](p P) *mcp.ServerRequest[P] { + return &mcp.ServerRequest[P]{Params: p} +} diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index 59412b15..2fbd695e 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -24,10 +24,10 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[SayHiParams]) (*mcp.CallToolResultFor[any], error) { +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) { return &mcp.CallToolResultFor[any]{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, }, }, nil } diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 3746e194..3aa1037c 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -16,9 +16,9 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiParams]]) (*mcp.CallToolResultFor[any], error) { return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + params.Arguments.Name}}, + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}}, }, nil } diff --git a/mcp/client.go b/mcp/client.go index 5798dc5a..b0db1d64 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -24,8 +24,8 @@ type Client struct { mu sync.Mutex roots *featureSet[*Root] sessions []*ClientSession - sendingMethodHandler_ MethodHandler[*ClientSession] - receivingMethodHandler_ MethodHandler[*ClientSession] + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler } // NewClient creates a new [Client]. @@ -55,14 +55,14 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { type ClientOptions struct { // Handler for sampling. // Called when a server calls CreateMessage. - CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) + CreateMessageHandler func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) // Handlers for notifications from the server. - ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) - PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) - ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) - ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) - LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) - ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams) + ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams]) + PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams]) + ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams]) + ResourceUpdatedHandler func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) + LoggingMessageHandler func(context.Context, *ClientRequest[*LoggingMessageParams]) + ProgressNotificationHandler func(context.Context, *ClientRequest[*ProgressNotificationParams]) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. @@ -130,7 +130,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio ClientInfo: c.impl, Capabilities: caps, } - res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params) + req := &ClientRequest[*InitializeParams]{Session: cs, Params: params} + res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) if err != nil { _ = cs.Close() return nil, err @@ -142,7 +143,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } - if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { + req2 := &ClientRequest[*InitializedParams]{Session: cs, Params: &InitializedParams{}} + if err := handleNotify(ctx, notificationInitialized, req2); err != nil { _ = cs.Close() return nil, err } @@ -216,7 +218,7 @@ func (c *Client) AddRoots(roots ...*Root) { if len(roots) == 0 { return } - c.changeAndNotify(notificationRootsListChanged, &RootsListChangedParams{}, + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, func() bool { c.roots.add(roots...); return true }) } @@ -225,14 +227,14 @@ func (c *Client) AddRoots(roots ...*Root) { // It is not an error to remove a nonexistent root. // TODO: notification func (c *Client) RemoveRoots(uris ...string) { - c.changeAndNotify(notificationRootsListChanged, &RootsListChangedParams{}, + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, func() bool { return c.roots.remove(uris...) }) } // changeAndNotify is called when a feature is added or removed. // It calls change, which should do the work and report whether a change actually occurred. // If there was a change, it notifies a snapshot of the sessions. -func (c *Client) changeAndNotify(notification string, params Params, change func() bool) { +func changeAndNotify[P Params](c *Client, notification string, params P, change func() bool) { var sessions []*ClientSession // Lock for the change, but not for the notification. c.mu.Lock() @@ -243,7 +245,7 @@ func (c *Client) changeAndNotify(notification string, params Params, change func notifySessions(sessions, notification, params) } -func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsParams) (*ListRootsResult, error) { +func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParams]) (*ListRootsResult, error) { c.mu.Lock() defer c.mu.Unlock() roots := slices.Collect(c.roots.all()) @@ -255,12 +257,12 @@ func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsPara }, nil } -func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *CreateMessageParams) (*CreateMessageResult, error) { +func (c *Client) createMessage(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage") } - return c.opts.CreateMessageHandler(ctx, cs, params) + return c.opts.CreateMessageHandler(ctx, req) } // AddSendingMiddleware wraps the current sending method handler using the provided @@ -272,7 +274,7 @@ func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *C // // Sending middleware is called when a request is sent. It is useful for tasks // such as tracing, metrics, and adding progress tokens. -func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) { +func (c *Client) AddSendingMiddleware(middleware ...Middleware) { c.mu.Lock() defer c.mu.Unlock() addMiddleware(&c.sendingMethodHandler_, middleware) @@ -287,7 +289,7 @@ func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) // // Receiving middleware is called when a request is received. It is useful for tasks // such as authentication, request logging and metrics. -func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]) { +func (c *Client) AddReceivingMiddleware(middleware ...Middleware) { c.mu.Lock() defer c.mu.Unlock() addMiddleware(&c.receivingMethodHandler_, middleware) @@ -299,16 +301,16 @@ func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession] // TODO(rfindley): actually load and validate the protocol schema, rather than // curating these method flags. var clientMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), 0), - methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), missingParamsOK), - methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), - methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), 0), - notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), - notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), - notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), - notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), - notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), notification), - notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), notification), + methodComplete: newClientMethodInfo(clientSessionMethod((*ClientSession).Complete), 0), + methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), + methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), + methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), + notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), + notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), + notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), + notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), + notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -342,25 +344,29 @@ func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { return &emptyResult{}, nil } +func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] { + return &ClientRequest[P]{Session: cs, Params: params} +} + // Ping makes an MCP "ping" request to the server. func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error { - _, err := handleSend[*emptyResult](ctx, cs, methodPing, orZero[Params](params)) + _, err := handleSend[*emptyResult](ctx, methodPing, newClientRequest(cs, orZero[Params](params))) return err } // ListPrompts lists prompts that are currently available on the server. func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { - return handleSend[*ListPromptsResult](ctx, cs, methodListPrompts, orZero[Params](params)) + return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params))) } // GetPrompt gets a prompt from the server. func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { - return handleSend[*GetPromptResult](ctx, cs, methodGetPrompt, orZero[Params](params)) + return handleSend[*GetPromptResult](ctx, methodGetPrompt, newClientRequest(cs, orZero[Params](params))) } // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return handleSend[*ListToolsResult](ctx, cs, methodListTools, orZero[Params](params)) + return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) } // CallTool calls the tool with the given name and arguments. @@ -373,72 +379,87 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } - return handleSend[*CallToolResult](ctx, cs, methodCallTool, params) + return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { - _, err := handleSend[*emptyResult](ctx, cs, methodSetLevel, orZero[Params](params)) + _, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params))) return err } // ListResources lists the resources that are currently available on the server. func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { - return handleSend[*ListResourcesResult](ctx, cs, methodListResources, orZero[Params](params)) + return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params))) } // ListResourceTemplates lists the resource templates that are currently available on the server. func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { - return handleSend[*ListResourceTemplatesResult](ctx, cs, methodListResourceTemplates, orZero[Params](params)) + return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params))) } // ReadResource asks the server to read a resource and return its contents. func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { - return handleSend[*ReadResourceResult](ctx, cs, methodReadResource, orZero[Params](params)) + return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params))) } func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) { - return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params)) + return handleSend[*CompleteResult](ctx, methodComplete, newClientRequest(cs, orZero[Params](params))) } // Subscribe sends a "resources/subscribe" request to the server, asking for // notifications when the specified resource changes. func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { - _, err := handleSend[*emptyResult](ctx, cs, methodSubscribe, orZero[Params](params)) + _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) return err } // Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling // a previous subscription. func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { - _, err := handleSend[*emptyResult](ctx, cs, methodUnsubscribe, orZero[Params](params)) + _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) return err } -func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) +func (c *Client) callToolChangedHandler(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) (Result, error) { + if h := c.opts.ToolListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (c *Client) callPromptChangedHandler(ctx context.Context, s *ClientSession, params *PromptListChangedParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.PromptListChangedHandler, s, params) +func (c *Client) callPromptChangedHandler(ctx context.Context, req *ClientRequest[*PromptListChangedParams]) (Result, error) { + if h := c.opts.PromptListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSession, params *ResourceListChangedParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) +func (c *Client) callResourceChangedHandler(ctx context.Context, req *ClientRequest[*ResourceListChangedParams]) (Result, error) { + if h := c.opts.ResourceListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (c *Client) callResourceUpdatedHandler(ctx context.Context, s *ClientSession, params *ResourceUpdatedNotificationParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.ResourceUpdatedHandler, s, params) +func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ClientRequest[*ResourceUpdatedNotificationParams]) (Result, error) { + if h := c.opts.ResourceUpdatedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) { +func (c *Client) callLoggingHandler(ctx context.Context, req *ClientRequest[*LoggingMessageParams]) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { - h(ctx, cs, params) + h(ctx, req) } return nil, nil } func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { - return callNotificationHandler(ctx, cs.client.opts.ProgressNotificationHandler, cs, params) + if h := cs.client.opts.ProgressNotificationHandler; h != nil { + h(ctx, clientRequestFor(cs, params)) + } + return nil, nil } // NotifyProgress sends a progress notification from the client to the server @@ -446,7 +467,7 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa // This can be used if the client is performing a long-running task that was // initiated by the server func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { - return handleNotify(ctx, cs, notificationProgress, params) + return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params))) } // Tools provides an iterator for all tools available on the server, diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index c91250c3..56f7428a 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -29,36 +29,35 @@ func Example_loggingMiddleware() { }, })) - loggingMiddleware := func(next mcp.MethodHandler[*mcp.ServerSession]) mcp.MethodHandler[*mcp.ServerSession] { + loggingMiddleware := func(next mcp.MethodHandler) mcp.MethodHandler { return func( ctx context.Context, - session *mcp.ServerSession, method string, - params mcp.Params, + req mcp.Request, ) (mcp.Result, error) { logger.Info("MCP method started", "method", method, - "session_id", session.ID(), - "has_params", params != nil, + "session_id", req.GetSession().ID(), + "has_params", req.GetParams() != nil, ) start := time.Now() - result, err := next(ctx, session, method, params) + result, err := next(ctx, method, req) duration := time.Since(start) if err != nil { logger.Error("MCP method failed", "method", method, - "session_id", session.ID(), + "session_id", req.GetSession().ID(), "duration_ms", duration.Milliseconds(), "err", err, ) } else { logger.Info("MCP method completed", "method", method, - "session_id", session.ID(), + "session_id", req.GetSession().ID(), "duration_ms", duration.Milliseconds(), "has_result", result != nil, ) @@ -90,10 +89,9 @@ func Example_loggingMiddleware() { }, func( ctx context.Context, - ss *mcp.ServerSession, - params *mcp.CallToolParamsFor[map[string]any], + req *mcp.ServerRequest[*mcp.CallToolParamsFor[map[string]any]], ) (*mcp.CallToolResultFor[any], error) { - name, ok := params.Arguments["name"].(string) + name, ok := req.Params.Arguments["name"].(string) if !ok { return nil, fmt.Errorf("name parameter is required and must be a string") } diff --git a/mcp/example_progress_test.go b/mcp/example_progress_test.go index 6c771e20..304c838a 100644 --- a/mcp/example_progress_test.go +++ b/mcp/example_progress_test.go @@ -21,11 +21,11 @@ func Example_progressMiddleware() { _ = c } -func addProgressToken[S mcp.Session](h mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, s S, method string, params mcp.Params) (result mcp.Result, err error) { - if rp, ok := params.(mcp.RequestParams); ok { +func addProgressToken[S mcp.Session](h mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + if rp, ok := req.GetParams().(mcp.RequestParams); ok { rp.SetProgressToken(nextProgressToken.Add(1)) } - return h(ctx, s, method, params) + return h(ctx, method, req) } } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 9e3eccd0..58b0377e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -32,11 +32,11 @@ type hiParams struct { // TODO(jba): after schemas are stateless (WIP), this can be a variable. func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } -func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { - if err := ss.Ping(ctx, nil); err != nil { +func sayHi(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { + if err := req.Session.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) } - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + params.Arguments.Name}}}, nil + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } var codeReviewPrompt = &Prompt{ @@ -73,16 +73,20 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *ServerSession, *InitializedParams) { notificationChans["initialized"] <- 0 }, - RootsListChangedHandler: func(context.Context, *ServerSession, *RootsListChangedParams) { notificationChans["roots"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { + InitializedHandler: func(context.Context, *ServerRequest[*InitializedParams]) { + notificationChans["initialized"] <- 0 + }, + RootsListChangedHandler: func(context.Context, *ServerRequest[*RootsListChangedParams]) { + notificationChans["roots"] <- 0 + }, + ProgressNotificationHandler: func(context.Context, *ServerRequest[*ProgressNotificationParams]) { notificationChans["progress_server"] <- 0 }, - SubscribeHandler: func(context.Context, *ServerSession, *SubscribeParams) error { + SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { notificationChans["subscribe"] <- 0 return nil }, - UnsubscribeHandler: func(context.Context, *ServerSession, *UnsubscribeParams) error { + UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { notificationChans["unsubscribe"] <- 0 return nil }, @@ -93,7 +97,7 @@ func TestEndToEnd(t *testing.T) { Description: "say hi", }, sayHi) s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { return nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) @@ -124,19 +128,25 @@ func TestEndToEnd(t *testing.T) { loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, - ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, - PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, - ResourceListChangedHandler: func(context.Context, *ClientSession, *ResourceListChangedParams) { notificationChans["resources"] <- 0 }, - LoggingMessageHandler: func(_ context.Context, _ *ClientSession, lm *LoggingMessageParams) { - loggingMessages <- lm + ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) { + notificationChans["tools"] <- 0 + }, + PromptListChangedHandler: func(context.Context, *ClientRequest[*PromptListChangedParams]) { + notificationChans["prompts"] <- 0 + }, + ResourceListChangedHandler: func(context.Context, *ClientRequest[*ResourceListChangedParams]) { + notificationChans["resources"] <- 0 + }, + LoggingMessageHandler: func(_ context.Context, req *ClientRequest[*LoggingMessageParams]) { + loggingMessages <- req.Params }, - ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { + ProgressNotificationHandler: func(context.Context, *ClientRequest[*ProgressNotificationParams]) { notificationChans["progress_client"] <- 0 }, - ResourceUpdatedHandler: func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) { + ResourceUpdatedHandler: func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) { notificationChans["resource_updated"] <- 0 }, } @@ -500,8 +510,8 @@ var embeddedResources = map[string]string{ "info": "This is the MCP test server.", } -func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { - u, err := url.Parse(params.URI) +func handleEmbeddedResource(_ context.Context, req *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) { + u, err := url.Parse(req.Params.URI) if err != nil { return nil, err } @@ -515,7 +525,7 @@ func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadRes } return &ReadResourceResult{ Contents: []*ResourceContents{ - {URI: params.URI, MIMEType: "text/plain", Text: text}, + {URI: req.Params.URI, MIMEType: "text/plain", Text: text}, }, }, nil } @@ -637,7 +647,7 @@ func TestCancellation(t *testing.T) { cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -816,17 +826,17 @@ func TestNoJSONNull(t *testing.T) { // traceCalls creates a middleware function that prints the method before and after each call // with the given prefix. -func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { - return func(h MethodHandler[S]) MethodHandler[S] { - return func(ctx context.Context, sess S, method string, params Params) (Result, error) { +func traceCalls[S Session](w io.Writer, prefix string) Middleware { + return func(h MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { fmt.Fprintf(w, "%s >%s\n", prefix, method) defer fmt.Fprintf(w, "%s <%s\n", prefix, method) - return h(ctx, sess, method, params) + return h(ctx, method, req) } } } -func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { +func nopHandler(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { return nil, nil } diff --git a/mcp/resource.go b/mcp/resource.go index 590e0672..5445715b 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -35,7 +35,7 @@ type serverResourceTemplate struct { // A ResourceHandler is a function that reads a resource. // It will be called when the client calls [ClientSession.ReadResource]. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) +type ResourceHandler func(context.Context, *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) // ResourceNotFoundError returns an error indicating that a resource being read could // not be found. diff --git a/mcp/server.go b/mcp/server.go index 89f3b6c9..e39372dc 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -41,8 +41,8 @@ type Server struct { resources *featureSet[*serverResource] resourceTemplates *featureSet[*serverResourceTemplate] sessions []*ServerSession - sendingMethodHandler_ MethodHandler[*ServerSession] - receivingMethodHandler_ MethodHandler[*ServerSession] + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool } @@ -51,24 +51,24 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *ServerSession, *InitializedParams) + InitializedHandler func(context.Context, *ServerRequest[*InitializedParams]) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. - RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams) + RootsListChangedHandler func(context.Context, *ServerRequest[*RootsListChangedParams]) // If non-nil, called when "notifications/progress" is received. - ProgressNotificationHandler func(context.Context, *ServerSession, *ProgressNotificationParams) + ProgressNotificationHandler func(context.Context, *ServerRequest[*ProgressNotificationParams]) // If non-nil, called when "completion/complete" is received. - CompletionHandler func(context.Context, *ServerSession, *CompleteParams) (*CompleteResult, error) + CompletionHandler func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *ServerSession, *SubscribeParams) error + SubscribeHandler func(context.Context, *ServerRequest[*SubscribeParams]) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *ServerSession, *UnsubscribeParams) error + UnsubscribeHandler func(context.Context, *ServerRequest[*UnsubscribeParams]) error // If true, advertises the prompts capability during initialization, // even if no prompts have been registered. HasPrompts bool @@ -258,11 +258,11 @@ func (s *Server) capabilities() *serverCapabilities { return caps } -func (s *Server) complete(ctx context.Context, ss *ServerSession, params *CompleteParams) (Result, error) { +func (s *Server) complete(ctx context.Context, req *ServerRequest[*CompleteParams]) (Result, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } - return s.opts.CompletionHandler(ctx, ss, params) + return s.opts.CompletionHandler(ctx, req) } // changeAndNotify is called when a feature is added or removed. @@ -287,13 +287,13 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] { return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPromptsParams) (*ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, req *ServerRequest[*ListPromptsParams]) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListPromptsParams{} + if req.Params == nil { + req.Params = &ListPromptsParams{} } - return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { + return paginateList(s.prompts, s.opts.PageSize, req.Params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { res.Prompts = []*Prompt{} // avoid JSON null for _, p := range prompts { res.Prompts = append(res.Prompts, p.prompt) @@ -301,24 +301,24 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr }) } -func (s *Server) getPrompt(ctx context.Context, ss *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptParams]) (*GetPromptResult, error) { s.mu.Lock() - prompt, ok := s.prompts.get(params.Name) + prompt, ok := s.prompts.get(req.Params.Name) s.mu.Unlock() if !ok { // TODO: surface the error code over the wire, instead of flattening it into the string. - return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) + return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, req.Params.Name) } - return prompt.handler(ctx, ss, params) + return prompt.handler(ctx, req.Session, req.Params) } -func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParams]) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListToolsParams{} + if req.Params == nil { + req.Params = &ListToolsParams{} } - return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { + return paginateList(s.tools, s.opts.PageSize, req.Params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { res.Tools = []*Tool{} // avoid JSON null for _, t := range tools { res.Tools = append(res.Tools, t.tool) @@ -326,23 +326,23 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool }) } -func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParamsFor[json.RawMessage]]) (*CallToolResult, error) { s.mu.Lock() - st, ok := s.tools.get(params.Name) + st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() if !ok { - return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name) + return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name) } - return st.handler(ctx, cc, params) + return st.handler(ctx, req) } -func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { +func (s *Server) listResources(_ context.Context, req *ServerRequest[*ListResourcesParams]) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListResourcesParams{} + if req.Params == nil { + req.Params = &ListResourcesParams{} } - return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { + return paginateList(s.resources, s.opts.PageSize, req.Params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { res.Resources = []*Resource{} // avoid JSON null for _, r := range resources { res.Resources = append(res.Resources, r.resource) @@ -350,13 +350,13 @@ func (s *Server) listResources(_ context.Context, _ *ServerSession, params *List }) } -func (s *Server) listResourceTemplates(_ context.Context, _ *ServerSession, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { +func (s *Server) listResourceTemplates(_ context.Context, req *ServerRequest[*ListResourceTemplatesParams]) (*ListResourceTemplatesResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListResourceTemplatesParams{} + if req.Params == nil { + req.Params = &ListResourceTemplatesParams{} } - return paginateList(s.resourceTemplates, s.opts.PageSize, params, &ListResourceTemplatesResult{}, + return paginateList(s.resourceTemplates, s.opts.PageSize, req.Params, &ListResourceTemplatesResult{}, func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null for _, rt := range rts { @@ -365,8 +365,8 @@ func (s *Server) listResourceTemplates(_ context.Context, _ *ServerSession, para }) } -func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { - uri := params.URI +func (s *Server) readResource(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) { + uri := req.Params.URI // Look up the resource URI in the lists of resources and resource templates. // This is a security check as well as an information lookup. handler, mimeType, ok := s.lookupResourceHandler(uri) @@ -375,7 +375,7 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re // Treat an unregistered resource the same as a registered one that couldn't be found. return nil, ResourceNotFoundError(uri) } - res, err := handler(ctx, ss, params) + res, err := handler(ctx, req) if err != nil { return nil, err } @@ -430,11 +430,11 @@ func fileResourceHandler(dir string) ResourceHandler { if err != nil { panic(err) } - return func(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (_ *ReadResourceResult, err error) { - defer util.Wrapf(&err, "reading resource %s", params.URI) + return func(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (_ *ReadResourceResult, err error) { + defer util.Wrapf(&err, "reading resource %s", req.Params.URI) // TODO: use a memoizing API here. - rootRes, err := ss.ListRoots(ctx, nil) + rootRes, err := req.Session.ListRoots(ctx, nil) if err != nil { return nil, fmt.Errorf("listing roots: %w", err) } @@ -442,13 +442,13 @@ func fileResourceHandler(dir string) ResourceHandler { if err != nil { return nil, err } - data, err := readFileResource(params.URI, dirFilepath, roots) + data, err := readFileResource(req.Params.URI, dirFilepath, roots) if err != nil { return nil, err } // TODO(jba): figure out mime type. Omit for now: Server.readResource will fill it in. return &ReadResourceResult{Contents: []*ResourceContents{ - {URI: params.URI, Blob: data}, + {URI: req.Params.URI, Blob: data}, }}, nil } } @@ -465,39 +465,39 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot return nil } -func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *SubscribeParams) (*emptyResult, error) { +func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribeParams]) (*emptyResult, error) { if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } - if err := s.opts.SubscribeHandler(ctx, ss, params); err != nil { + if err := s.opts.SubscribeHandler(ctx, req); err != nil { return nil, err } s.mu.Lock() defer s.mu.Unlock() - if s.resourceSubscriptions[params.URI] == nil { - s.resourceSubscriptions[params.URI] = make(map[*ServerSession]bool) + if s.resourceSubscriptions[req.Params.URI] == nil { + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) } - s.resourceSubscriptions[params.URI][ss] = true + s.resourceSubscriptions[req.Params.URI][req.Session] = true return &emptyResult{}, nil } -func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *UnsubscribeParams) (*emptyResult, error) { +func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*UnsubscribeParams]) (*emptyResult, error) { if s.opts.UnsubscribeHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } - if err := s.opts.UnsubscribeHandler(ctx, ss, params); err != nil { + if err := s.opts.UnsubscribeHandler(ctx, req); err != nil { return nil, err } s.mu.Lock() defer s.mu.Unlock() - if subscribedSessions, ok := s.resourceSubscriptions[params.URI]; ok { - delete(subscribedSessions, ss) + if subscribedSessions, ok := s.resourceSubscriptions[req.Params.URI]; ok { + delete(subscribedSessions, req.Session) if len(subscribedSessions) == 0 { - delete(s.resourceSubscriptions, params.URI) + delete(s.resourceSubscriptions, req.Params.URI) } } @@ -611,15 +611,24 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar if wasInitd { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } - return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) + if h := ss.server.opts.InitializedHandler; h != nil { + h(ctx, serverRequestFor(ss, params)) + } + return nil, nil } -func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSession, params *RootsListChangedParams) (Result, error) { - return callNotificationHandler(ctx, s.opts.RootsListChangedHandler, ss, params) +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *ServerRequest[*RootsListChangedParams]) (Result, error) { + if h := s.opts.RootsListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { - return callNotificationHandler(ctx, ss.server.opts.ProgressNotificationHandler, ss, params) +func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { + if h := ss.server.opts.ProgressNotificationHandler; h != nil { + h(ctx, serverRequestFor(ss, p)) + } + return nil, nil } // NotifyProgress sends a progress notification from the server to the client @@ -627,7 +636,11 @@ func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, pa // This is typically used to report on the status of a long-running request // that was initiated by the client. func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { - return handleNotify(ctx, ss, notificationProgress, params) + return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) +} + +func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { + return &ServerRequest[P]{Session: ss, Params: params} } // A ServerSession is a logical connection from a single MCP client. Its @@ -665,18 +678,18 @@ func (ss *ServerSession) ID() string { // Ping pings the client. func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { - _, err := handleSend[*emptyResult](ctx, ss, methodPing, orZero[Params](params)) + _, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params))) return err } // ListRoots lists the client roots. func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { - return handleSend[*ListRootsResult](ctx, ss, methodListRoots, orZero[Params](params)) + return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params))) } // CreateMessage sends a sampling request to the client. func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { - return handleSend[*CreateMessageResult](ctx, ss, methodCreateMessage, orZero[Params](params)) + return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) } // Log sends a log message to the client. @@ -695,7 +708,7 @@ func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) if compareLevels(params.Level, logLevel) < 0 { return nil } - return handleNotify(ctx, ss, notificationLoggingMessage, params) + return handleNotify(ctx, notificationLoggingMessage, newServerRequest(ss, orZero[Params](params))) } // AddSendingMiddleware wraps the current sending method handler using the provided @@ -707,7 +720,7 @@ func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) // // Sending middleware is called when a request is sent. It is useful for tasks // such as tracing, metrics, and adding progress tokens. -func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) { +func (s *Server) AddSendingMiddleware(middleware ...Middleware) { s.mu.Lock() defer s.mu.Unlock() addMiddleware(&s.sendingMethodHandler_, middleware) @@ -722,7 +735,7 @@ func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) // // Receiving middleware is called when a request is received. It is useful for tasks // such as authentication, request logging and metrics. -func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]) { +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { s.mu.Lock() defer s.mu.Unlock() addMiddleware(&s.receivingMethodHandler_, middleware) @@ -734,22 +747,22 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession] // TODO(rfindley): actually load and validate the protocol schema, rather than // curating these method flags. var serverMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(serverMethod((*Server).complete), 0), - methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), 0), - methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), missingParamsOK), - methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), - methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), 0), - methodListTools: newMethodInfo(serverMethod((*Server).listTools), missingParamsOK), - methodCallTool: newMethodInfo(serverMethod((*Server).callTool), 0), - methodListResources: newMethodInfo(serverMethod((*Server).listResources), missingParamsOK), - methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), - methodReadResource: newMethodInfo(serverMethod((*Server).readResource), 0), - methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), 0), - methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), 0), - methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), 0), - notificationInitialized: newMethodInfo(sessionMethod((*ServerSession).initialized), notification|missingParamsOK), - notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), - notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), notification), + methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), + methodInitialize: newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0), + methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), + methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), + methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), + methodListTools: newServerMethodInfo(serverMethod((*Server).listTools), missingParamsOK), + methodCallTool: newServerMethodInfo(serverMethod((*Server).callTool), 0), + methodListResources: newServerMethodInfo(serverMethod((*Server).listResources), missingParamsOK), + methodListResourceTemplates: newServerMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), + methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), + methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), + methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), + notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), + notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), } func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } @@ -757,15 +770,17 @@ func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return cli func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } func (ss *ServerSession) sendingMethodHandler() methodHandler { - ss.server.mu.Lock() - defer ss.server.mu.Unlock() - return ss.server.sendingMethodHandler_ + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.sendingMethodHandler_ } func (ss *ServerSession) receivingMethodHandler() methodHandler { - ss.server.mu.Lock() - defer ss.server.mu.Unlock() - return ss.server.receivingMethodHandler_ + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.receivingMethodHandler_ } // getConn implements [session.getConn]. @@ -808,13 +823,14 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam version = latestProtocolVersion } + s := ss.server return &InitializeResult{ // TODO(rfindley): alter behavior when falling back to an older version: // reject unsupported features. ProtocolVersion: version, - Capabilities: ss.server.capabilities(), - Instructions: ss.server.opts.Instructions, - ServerInfo: ss.server.impl, + Capabilities: s.capabilities(), + Instructions: s.opts.Instructions, + ServerInfo: s.impl, }, nil } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 241008e9..f735b84e 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -16,10 +16,10 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[SayHiParams]) (*mcp.CallToolResultFor[any], error) { +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) { return &mcp.CallToolResultFor[any]{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, }, }, nil } diff --git a/mcp/server_test.go b/mcp/server_test.go index 6415decc..202ab5d9 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -281,10 +281,10 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(ctx context.Context, _ *ServerSession, sp *SubscribeParams) error { + SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { return nil }, - UnsubscribeHandler: func(ctx context.Context, _ *ServerSession, up *UnsubscribeParams) error { + UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { return nil }, }, @@ -307,7 +307,7 @@ func TestServerCapabilities(t *testing.T) { name: "With completions", configureServer: func(s *Server) {}, serverOpts: ServerOptions{ - CompletionHandler: func(ctx context.Context, ss *ServerSession, params *CompleteParams) (*CompleteResult, error) { + CompletionHandler: func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) { return nil, nil }, }, @@ -325,13 +325,13 @@ func TestServerCapabilities(t *testing.T) { s.AddTool(tool, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(ctx context.Context, _ *ServerSession, sp *SubscribeParams) error { + SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { return nil }, - UnsubscribeHandler: func(ctx context.Context, _ *ServerSession, up *UnsubscribeParams) error { + UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { return nil }, - CompletionHandler: func(ctx context.Context, ss *ServerSession, params *CompleteParams) (*CompleteResult, error) { + CompletionHandler: func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) { return nil, nil }, }, diff --git a/mcp/shared.go b/mcp/shared.go index e3688641..ca062214 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -36,7 +36,7 @@ var supportedProtocolVersions = []string{ // A MethodHandler handles MCP messages. // For methods, exactly one of the return values must be nil. // For notifications, both must be nil. -type MethodHandler[S Session] func(ctx context.Context, _ S, method string, params Params) (result Result, err error) +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) // A methodHandler is a MethodHandler[Session] for some session. // We need to give up type safety here, or we will end up with a type cycle somewhere @@ -46,7 +46,6 @@ type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerS // A Session is either a [ClientSession] or a [ServerSession]. type Session interface { - *ClientSession | *ServerSession // ID returns the session ID, or the empty string if there is none. ID() string @@ -58,29 +57,29 @@ type Session interface { } // Middleware is a function from [MethodHandler] to [MethodHandler]. -type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] +type Middleware func(MethodHandler) MethodHandler // addMiddleware wraps the handler in the middleware functions. -func addMiddleware[S Session](handlerp *MethodHandler[S], middleware []Middleware[S]) { +func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { for _, m := range slices.Backward(middleware) { *handlerp = m(*handlerp) } } -func defaultSendingMethodHandler[S Session](ctx context.Context, session S, method string, params Params) (Result, error) { - info, ok := session.sendingMethodInfos()[method] +func defaultSendingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().sendingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } // Notifications don't have results. if strings.HasPrefix(method, "notifications/") { - return nil, session.getConn().Notify(ctx, method, params) + return nil, req.GetSession().getConn().Notify(ctx, method, req.GetParams()) } // Create the result to unmarshal into. // The concrete type of the result is the return type of the receiving function. res := info.newResult() - if err := call(ctx, session.getConn(), method, params, res); err != nil { + if err := call(ctx, req.GetSession().getConn(), method, req.GetParams(), res); err != nil { return nil, err } return res, nil @@ -95,16 +94,16 @@ func orZero[T any, P *U, U any](p P) T { return any(p).(T) } -func handleNotify[S Session](ctx context.Context, session S, method string, params Params) error { - mh := session.sendingMethodHandler().(MethodHandler[S]) - _, err := mh(ctx, session, method, params) +func handleNotify(ctx context.Context, method string, req Request) error { + mh := req.GetSession().sendingMethodHandler().(MethodHandler) + _, err := mh(ctx, method, req) return err } -func handleSend[R Result, S Session](ctx context.Context, s S, method string, params Params) (R, error) { - mh := s.sendingMethodHandler().(MethodHandler[S]) +func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { + mh := req.GetSession().sendingMethodHandler().(MethodHandler) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. - res, err := mh(ctx, s, method, params) + res, err := mh(ctx, method, req) if err != nil { var z R return z, err @@ -113,28 +112,29 @@ func handleSend[R Result, S Session](ctx context.Context, s S, method string, pa } // defaultReceivingMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. -func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, method string, params Params) (Result, error) { - info, ok := session.receivingMethodInfos()[method] +func defaultReceivingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().receivingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } - return info.handleMethod.(MethodHandler[S])(ctx, session, method, params) + return info.handleMethod.(MethodHandler)(ctx, method, req) } -func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Request) (Result, error) { - info, err := checkRequest(req, session.receivingMethodInfos()) +func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { + info, err := checkRequest(jreq, session.receivingMethodInfos()) if err != nil { return nil, err } - params, err := info.unmarshalParams(req.Params) + params, err := info.unmarshalParams(jreq.Params) if err != nil { - return nil, fmt.Errorf("handling '%s': %w", req.Method, err) + return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) } - mh := session.receivingMethodHandler().(MethodHandler[S]) + mh := session.receivingMethodHandler().(MethodHandler) + req := info.newRequest(session, params) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. - res, err := mh(ctx, session, req.Method, params) + res, err := mh(ctx, jreq.Method, req) if err != nil { return nil, err } @@ -179,6 +179,7 @@ type methodInfo struct { // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) + newRequest func(Session, Params) Request // Run the code when a call to the method is received. // Used on the receive side. handleMethod methodHandler @@ -194,7 +195,10 @@ type methodInfo struct { // - R: results // A typedMethodHandler is like a MethodHandler, but with type information. -type typedMethodHandler[S Session, P Params, R Result] func(context.Context, S, P) (R, error) +type ( + typedClientMethodHandler[P Params, R Result] func(context.Context, *ClientRequest[P]) (R, error) + typedServerMethodHandler[P Params, R Result] func(context.Context, *ServerRequest[P]) (R, error) +) type paramsPtr[T any] interface { *T @@ -208,11 +212,45 @@ const ( missingParamsOK // params may be missing or null ) +func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params) Request { + r := &ClientRequest[P]{Session: s.(*ClientSession)} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ClientRequest[P])) + }) + return mi +} + +func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession)} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)} + if req.GetParams() != nil { + rf.Params = req.GetParams().(P) + } + return d(ctx, rf) + }) + return mi +} + // newMethodInfo creates a methodInfo from a typedMethodHandler. // // If isRequest is set, the method is treated as a request rather than a // notification. -func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], flags methodFlags) methodInfo { +func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInfo { return methodInfo{ flags: flags, unmarshalParams: func(m json.RawMessage) (Params, error) { @@ -234,12 +272,6 @@ func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHand } return orZero[Params](p), nil }, - handleMethod: MethodHandler[S](func(ctx context.Context, session S, _ string, params Params) (Result, error) { - if params == nil { - return d(ctx, session, nil) - } - return d(ctx, session, params.(P)) - }), // newResult is used on the send side, to construct the value to unmarshal the result into. // R is a pointer to a result struct. There is no way to "unpointer" it without reflection. // TODO(jba): explore generic approaches to this, perhaps by treating R in @@ -250,26 +282,33 @@ func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHand // serverMethod is glue for creating a typedMethodHandler from a method on Server. func serverMethod[P Params, R Result]( - f func(*Server, context.Context, *ServerSession, P) (R, error), -) typedMethodHandler[*ServerSession, P, R] { - return func(ctx context.Context, ss *ServerSession, p P) (R, error) { - return f(ss.server, ctx, ss, p) + f func(*Server, context.Context, *ServerRequest[P]) (R, error), +) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.Session.server, ctx, req) } } // clientMethod is glue for creating a typedMethodHandler from a method on Client. func clientMethod[P Params, R Result]( - f func(*Client, context.Context, *ClientSession, P) (R, error), -) typedMethodHandler[*ClientSession, P, R] { - return func(ctx context.Context, cs *ClientSession, p P) (R, error) { - return f(cs.client, ctx, cs, p) + f func(*Client, context.Context, *ClientRequest[P]) (R, error), +) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.Session.client, ctx, req) + } +} + +// serverSessionMethod is glue for creating a typedServerMethodHandler from a method on ServerSession. +func serverSessionMethod[P Params, R Result](f func(*ServerSession, context.Context, P) (R, error)) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.GetSession().(*ServerSession), ctx, req.Params) } } -// sessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. -func sessionMethod[S Session, P Params, R Result](f func(S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { - return func(ctx context.Context, sess S, p P) (R, error) { - return f(sess, ctx, p) +// clientSessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. +func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Context, P) (R, error)) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.GetSession().(*ClientSession), ctx, req.Params) } } @@ -282,16 +321,9 @@ const ( CodeUnsupportedMethod = -31001 ) -func callNotificationHandler[S Session, P any](ctx context.Context, h func(context.Context, S, *P), sess S, params *P) (Result, error) { - if h != nil { - h(ctx, sess, params) - } - return nil, nil -} - // notifySessions calls Notify on all the sessions. // Should be called on a copy of the peer sessions. -func notifySessions[S Session](sessions []S, method string, params Params) { +func notifySessions[S Session, P Params](sessions []S, method string, params P) { if sessions == nil { return } @@ -299,13 +331,25 @@ func notifySessions[S Session](sessions []S, method string, params Params) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() for _, s := range sessions { - if err := handleNotify(ctx, s, method, params); err != nil { + req := newRequest(s, params) + if err := handleNotify(ctx, method, req); err != nil { // TODO(jba): surface this error better log.Printf("calling %s: %v", method, err) } } } +func newRequest[S Session, P Params](s S, p P) Request { + switch s := any(s).(type) { + case *ClientSession: + return &ClientRequest[P]{Session: s, Params: p} + case *ServerSession: + return &ServerRequest[P]{Session: s, Params: p} + default: + panic("bad session") + } +} + // Meta is additional metadata for requests, responses and other types. type Meta map[string]any @@ -335,6 +379,43 @@ func setProgressToken(p Params, pt any) { m[progressTokenKey] = pt } +// A Request is a method request with parameters and additional information, such as the session. +// Request is implemented by [*ClientRequest] and [*ServerRequest]. +type Request interface { + isRequest() + GetSession() Session + GetParams() Params +} + +// A ClientRequest is a request to a client. +type ClientRequest[P Params] struct { + Session *ClientSession + Params P +} + +// A ServerRequest is a request to a server. +type ServerRequest[P Params] struct { + Session *ServerSession + Params P +} + +func (*ClientRequest[P]) isRequest() {} +func (*ServerRequest[P]) isRequest() {} + +func (r *ClientRequest[P]) GetSession() Session { return r.Session } +func (r *ServerRequest[P]) GetSession() Session { return r.Session } + +func (r *ClientRequest[P]) GetParams() Params { return r.Params } +func (r *ServerRequest[P]) GetParams() Params { return r.Params } + +func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { + return &ServerRequest[P]{Session: s, Params: p} +} + +func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] { + return &ClientRequest[P]{Session: s, Params: p} +} + // Params is a parameter (input) type for an MCP call or notification. type Params interface { // isParams discourages implementation of Params outside of this package. diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 0aea1947..01d1eff7 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -23,7 +23,7 @@ func TestToolValidate(t *testing.T) { P *int `json:",omitempty"` } - dummyHandler := func(context.Context, *ServerSession, *CallToolParamsFor[req]) (*CallToolResultFor[any], error) { + dummyHandler := func(context.Context, *ServerRequest[*CallToolParamsFor[req]]) (*CallToolResultFor[any], error) { return nil, nil } @@ -73,8 +73,9 @@ func TestToolValidate(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = st.handler(context.Background(), nil, - &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}) + _, err = st.handler(context.Background(), &ServerRequest[*CallToolParamsFor[json.RawMessage]]{ + Params: &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}, + }) if err == nil && tt.want != "" { t.Error("got success, wanted failure") } @@ -102,12 +103,12 @@ func TestNilParamsHandling(t *testing.T) { type TestResult = *CallToolResultFor[string] // Simple test handler - testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (TestResult, error) { - result := "processed: " + params.Arguments.Name + testHandler := func(ctx context.Context, req *ServerRequest[TestParams]) (TestResult, error) { + result := "processed: " + req.Params.Arguments.Name return &CallToolResultFor[string]{StructuredContent: result}, nil } - methodInfo := newMethodInfo(testHandler, missingParamsOK) + methodInfo := newServerMethodInfo(testHandler, missingParamsOK) // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { @@ -183,11 +184,11 @@ func TestNilParamsEdgeCases(t *testing.T) { } type TestParams = *CallToolParamsFor[TestArgs] - testHandler := func(ctx context.Context, ss *ServerSession, params TestParams) (*CallToolResultFor[string], error) { + testHandler := func(context.Context, *ServerRequest[TestParams]) (*CallToolResultFor[string], error) { return &CallToolResultFor[string]{StructuredContent: "test"}, nil } - methodInfo := newMethodInfo(testHandler, missingParamsOK) + methodInfo := newServerMethodInfo(testHandler, missingParamsOK) // These should fail normally, not be treated as nil params invalidCases := []json.RawMessage{ @@ -209,7 +210,7 @@ func TestNilParamsEdgeCases(t *testing.T) { // Test that methods without missingParamsOK flag properly reject nil params t.Run("reject_when_params_required", func(t *testing.T) { - methodInfoStrict := newMethodInfo(testHandler, 0) // No missingParamsOK flag + methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag testCases := []struct { name string diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index cf1e75dc..b5dfdc56 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -18,10 +18,10 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[AddParams]) (*mcp.CallToolResultFor[any], error) { +func Add(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddParams]]) (*mcp.CallToolResultFor[any], error) { return &mcp.CallToolResultFor[any]{ Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("%d", params.Arguments.X+params.Arguments.Y)}, + &mcp.TextContent{Text: fmt.Sprintf("%d", req.Params.Arguments.X+req.Params.Arguments.Y)}, }, }, nil } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index ba7af9fa..fd1dc3e4 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -136,12 +136,12 @@ func TestClientReplay(t *testing.T) { serverReadyToKillProxy := make(chan struct{}) serverClosed := make(chan struct{}) server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { go func() { bgCtx := context.Background() // Send the first two messages immediately. - ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) - ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) // Signal the test that it can now kill the proxy. close(serverReadyToKillProxy) @@ -149,8 +149,8 @@ func TestClientReplay(t *testing.T) { // These messages should be queued for replay by the server after // the client's connection drops. - ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) - ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) }() return &CallToolResult{}, nil }) @@ -170,8 +170,8 @@ func TestClientReplay(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ - ProgressNotificationHandler: func(ctx context.Context, cc *ClientSession, params *ProgressNotificationParams) { - notifications <- params.Message + ProgressNotificationHandler: func(ctx context.Context, req *ClientRequest[*ProgressNotificationParams]) { + notifications <- req.Params.Message }, }) clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: proxy.URL}, nil) @@ -236,9 +236,10 @@ func TestServerInitiatedSSE(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { - notifications <- "toolListChanged" - }, + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) { + notifications <- "toolListChanged" + }, }) clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) if err != nil { @@ -246,7 +247,7 @@ func TestServerInitiatedSSE(t *testing.T) { } defer clientSession.Close() server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { return &CallToolResult{}, nil }) receivedNotifications := readNotifications(t, ctx, notifications, 1) @@ -509,9 +510,9 @@ func TestStreamableServerTransport(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) - AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { + AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[any]]) (*CallToolResultFor[any], error) { if test.tool != nil { - test.tool(t, ctx, ss) + test.tool(t, ctx, req.Session) } return &CallToolResultFor[any]{}, nil }) @@ -822,14 +823,15 @@ func TestEventID(t *testing.T) { }) } } + func TestStreamableStateless(t *testing.T) { // Test stateless mode behavior ctx := context.Background() // This version of sayHi doesn't make a ping request (we can't respond to // that request from our client). - sayHi := func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + params.Arguments.Name}}}, nil + sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) diff --git a/mcp/tool.go b/mcp/tool.go index 234cd659..15f17e11 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -20,10 +20,12 @@ import ( type ToolHandler = ToolHandlerFor[map[string]any, any] // A ToolHandlerFor handles a call to tools/call with typed arguments and results. -type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) +type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) // A rawToolHandler is like a ToolHandler, but takes the arguments as as json.RawMessage. -type rawToolHandler = func(context.Context, *ServerSession, *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) +// Second arg is *Request[*ServerSession, *CallToolParamsFor[json.RawMessage]], but that creates +// a cycle. +type rawToolHandler = func(context.Context, any) (*CallToolResult, error) // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { @@ -48,20 +50,25 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool } } - st.handler = func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { + st.handler = func(ctx context.Context, areq any) (*CallToolResult, error) { + req := areq.(*ServerRequest[*CallToolParamsFor[json.RawMessage]]) var args In - if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, st.inputResolved, &args); err != nil { + if req.Params.Arguments != nil { + if err := unmarshalSchema(req.Params.Arguments, st.inputResolved, &args); err != nil { return nil, err } } // TODO(jba): future-proof this copy. params := &CallToolParamsFor[In]{ - Meta: rparams.Meta, - Name: rparams.Name, + Meta: req.Params.Meta, + Name: req.Params.Name, Arguments: args, } - res, err := h(ctx, ss, params) + // TODO(jba): improve copy + res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ + Session: req.Session, + Params: params, + }) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. if err != nil { diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 52cac9fc..609536cc 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -16,7 +16,7 @@ import ( ) // testToolHandler is used for type inference in TestNewServerTool. -func testToolHandler[In, Out any](context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) { +func testToolHandler[In, Out any](context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) { panic("not implemented") } From 87f222477b31e542d33283f71358f829eb6a996b Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 13 Aug 2025 10:11:01 -0400 Subject: [PATCH 094/221] mcp: add a test for streamable sampling during a tool call Add a test that attempts (and fails) to reproduce the bug reported in issue #285. For #285 --- mcp/streamable_test.go | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index fd1dc3e4..25dd224e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -40,6 +40,25 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + // Test that we can make sampling requests during tool handling. + // + // Try this on both the request context and a background context, so + // that messages may be delivered on either the POST or GET connection. + for _, ctx := range map[string]context.Context{ + "request context": ctx, + "background context": context.Background(), + } { + res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{}) + if err != nil { + return nil, err + } + if g, w := res.Model, "aModel"; g != w { + return nil, fmt.Errorf("got %q, want %q", g, w) + } + } + return &CallToolResultFor[any]{}, nil + }) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. @@ -81,7 +100,11 @@ func TestStreamableTransports(t *testing.T) { Endpoint: httpServer.URL, HTTPClient: httpClient, } - client := NewClient(testImpl, nil) + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) session, err := client.Connect(ctx, transport, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) @@ -119,6 +142,19 @@ func TestStreamableTransports(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } + + // 6. Run the "sampling" tool and verify that the streamable server can + // call tools. + result, err := session.CallTool(ctx, &CallToolParams{ + Name: "sample", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatal(err) + } + if result.IsError { + t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text) + } }) } } From a834f3cf9acd0b79181cdb9a95ebe36755194ecc Mon Sep 17 00:00:00 2001 From: Riley Norton <46510750+RileyNorton@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:12:13 -0700 Subject: [PATCH 095/221] mcp: allow correct 'Accept' header values Allow other header values for the 'Accept' header that imply application/json and text/event-stream for requests to the Streamable HTTP Transport. Fixes #290 Unit tests here seemed like unnecessary complexity and wouldn't fit cleanly with the style of the `streamable_test.go` tests. Happy to add those if desired though. --- mcp/streamable.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 048b99aa..0d417bc8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -98,9 +98,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var jsonOK, streamOK bool for _, c := range accept { switch strings.TrimSpace(c) { - case "application/json": + case "application/json", "application/*": + jsonOK = true + case "text/event-stream", "text/*": + streamOK = true + case "*/*": jsonOK = true - case "text/event-stream": streamOK = true } } From bb6dadecca24ba53195559f05ef40904d48baeff Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 13 Aug 2025 15:37:56 +0000 Subject: [PATCH 096/221] mcp: fix cancellation for HTTP transport In #202, I added the checkRequest helper to validate incoming requests, and invoked it in the stremable transports to preemptively reject invalid HTTP requests, so that a jsonrpc error could be translated to an HTTP error. However, this introduced a bug: since cancellation was handled in the jsonrpc2 layer, we never had to validate it in the mcp layer, and therefore never added methodInfo. As a result, it was reported as an invalid request in the http layer. Add a test, and a fix. The simplest fix was to create stubs that are placeholders for cancellation. This was discovered in the course of investigating #285. --- mcp/client.go | 10 +++++++++ mcp/mcp_test.go | 5 ++--- mcp/server.go | 10 +++++++++ mcp/streamable.go | 20 ++++++++--------- mcp/streamable_test.go | 51 ++++++++++++++++++++++++++++++++---------- mcp/transport.go | 2 +- 6 files changed, 72 insertions(+), 26 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index b0db1d64..88eea7da 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -305,6 +305,7 @@ var clientMethodInfos = map[string]methodInfo{ methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), @@ -344,6 +345,15 @@ func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { return &emptyResult{}, nil } +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] { return &ClientRequest[P]{Session: cs, Params: params} } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 58b0377e..66ad7e0e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -646,8 +646,7 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - - slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -658,7 +657,7 @@ func TestCancellation(t *testing.T) { return nil, nil } _, cs := basicConnection(t, func(s *Server) { - s.AddTool(&Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest) + AddTool(s, &Tool{Name: "slow"}, slowRequest) }) defer cs.Close() diff --git a/mcp/server.go b/mcp/server.go index e39372dc..5b7538a1 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -760,6 +760,7 @@ var serverMethodInfos = map[string]methodInfo{ methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), @@ -838,6 +839,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error return &emptyResult{}, nil } +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { ss.updateState(func(state *ServerSessionState) { state.LogLevel = params.Level diff --git a/mcp/streamable.go b/mcp/streamable.go index 0d417bc8..5692b985 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -118,12 +118,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - var session *StreamableServerTransport + var transport *StreamableServerTransport if id := req.Header.Get(sessionIDHeader); id != "" { h.mu.Lock() - session, _ = h.transports[id] + transport, _ = h.transports[id] h.mu.Unlock() - if session == nil { + if transport == nil { http.Error(w, "session not found", http.StatusNotFound) return } @@ -132,22 +132,22 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // TODO(rfindley): simplify the locking so that each request has only one // critical section. if req.Method == http.MethodDelete { - if session == nil { + if transport == nil { // => Mcp-Session-Id was not set; else we'd have returned NotFound above. http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } h.mu.Lock() - delete(h.transports, session.SessionID) + delete(h.transports, transport.SessionID) h.mu.Unlock() - session.connection.Close() + transport.connection.Close() w.WriteHeader(http.StatusNoContent) return } switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && session == nil { + if req.Method == http.MethodGet && transport == nil { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } @@ -157,7 +157,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - if session == nil { + if transport == nil { server := h.getServer(req) if server == nil { // The getServer argument to NewStreamableHTTPHandler returned nil. @@ -194,10 +194,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.transports[s.SessionID] = s h.mu.Unlock() } - session = s + transport = s } - session.ServeHTTP(w, req) + transport.ServeHTTP(w, req) } // StreamableServerTransportOptions configures the stramable server transport. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 25dd224e..ca5e5a5c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -37,9 +37,26 @@ func TestStreamableTransports(t *testing.T) { for _, useJSON := range []bool{false, true} { t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) { - // 1. Create a server with a simple "greet" tool. + // Create a server with some simple tools. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + // The "hang" tool checks that context cancellation is propagated. + // It hangs until the context is cancelled. + var ( + start = make(chan struct{}) + cancelled = make(chan struct{}, 1) // don't block the request + ) + hang := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + start <- struct{}{} + select { + case <-ctx.Done(): + cancelled <- struct{}{} + case <-time.After(5 * time.Second): + return nil, nil + } + return nil, nil + } + AddTool(server, &Tool{Name: "hang"}, hang) AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { // Test that we can make sampling requests during tool handling. // @@ -60,7 +77,7 @@ func TestStreamableTransports(t *testing.T) { return &CallToolResultFor[any]{}, nil }) - // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a + // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ jsonResponse: useJSON, @@ -84,7 +101,7 @@ func TestStreamableTransports(t *testing.T) { })) defer httpServer.Close() - // 3. Create a client and connect it to the server using our StreamableClientTransport. + // Create a client and connect it to the server using our StreamableClientTransport. // Check that all requests honor a custom client. jar, err := cookiejar.New(nil) if err != nil { @@ -117,10 +134,13 @@ func TestStreamableTransports(t *testing.T) { if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { t.Fatalf("got protocol version %q, want %q", g, w) } - // 4. The client calls the "greet" tool. + + // Verify the behavior of various tools. + + // The "greet" tool should just work. params := &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "streamy"}, + Arguments: map[string]any{"name": "foo"}, } got, err := session.CallTool(ctx, params) if err != nil { @@ -132,19 +152,26 @@ func TestStreamableTransports(t *testing.T) { if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w { t.Errorf("got protocol version header %q, want %q", g, w) } - - // 5. Verify that the correct response is received. want := &CallToolResult{ - Content: []Content{ - &TextContent{Text: "hi streamy"}, - }, + Content: []Content{&TextContent{Text: "hi foo"}}, } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } - // 6. Run the "sampling" tool and verify that the streamable server can - // call tools. + // The "hang" tool should be cancellable. + ctx2, cancel := context.WithCancel(context.Background()) + go session.CallTool(ctx2, &CallToolParams{Name: "hang"}) + <-start + cancel() + select { + case <-cancelled: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for cancellation") + } + + // The "sampling" tool should be able to issue sampling requests during + // tool operation. result, err := session.CallTool(ctx, &CallToolParams{ Name: "sample", Arguments: map[string]any{}, diff --git a/mcp/transport.go b/mcp/transport.go index 76b79986..6d25de33 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -171,7 +171,7 @@ type canceller struct { // Preempt implements [jsonrpc2.Preempter]. func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { - if req.Method == "notifications/cancelled" { + if req.Method == notificationCancelled { var params CancelledParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err From 16f2857700f22948fb90d4dcf61e3258c389a326 Mon Sep 17 00:00:00 2001 From: Kamdyn Shaeffer <136033129+KamdynS@users.noreply.github.com> Date: Thu, 14 Aug 2025 13:38:31 -0400 Subject: [PATCH 097/221] mcp/tool: duplicate tools should not error (#295) This CL causes the dup to fail. --- mcp/mcp_test.go | 35 +++++++++++++++++++++++++++++++++++ mcp/server.go | 9 +++++++++ 2 files changed, 44 insertions(+) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 66ad7e0e..d04235b8 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -937,4 +937,39 @@ func TestKeepAliveFailure(t *testing.T) { t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err) } +func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { + // Adding the same tool pointer twice should not panic and should not + // produce duplicates in the server's tool list. + _, cs := basicConnection(t, func(s *Server) { + // Use two distinct Tool instances with the same name but different + // descriptions to ensure the second replaces the first + // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors + t1 := &Tool{Name: "dup", Description: "first", InputSchema: &jsonschema.Schema{}} + t2 := &Tool{Name: "dup", Description: "second", InputSchema: &jsonschema.Schema{}} + s.AddTool(t1, nopHandler) + s.AddTool(t2, nopHandler) + }) + defer cs.Close() + + ctx := context.Background() + res, err := cs.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + var count int + var gotDesc string + for _, tt := range res.Tools { + if tt.Name == "dup" { + count++ + gotDesc = tt.Description + } + } + if count != 1 { + t.Fatalf("expected exactly one 'dup' tool, got %d", count) + } + if gotDesc != "second" { + t.Fatalf("expected replaced tool to have description %q, got %q", "second", gotDesc) + } +} + var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/server.go b/mcp/server.go index 5b7538a1..ed4ec720 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -171,6 +171,15 @@ func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) { defer util.Wrapf(&err, "adding tool %q", t.Name) + // If the exact same Tool pointer has already been registered under this name, + // avoid rebuilding schemas and re-registering. This prevents duplicate + // registration from causing errors (and unnecessary work). + s.mu.Lock() + if existing, ok := s.tools.get(t.Name); ok && existing.tool == t { + s.mu.Unlock() + return nil + } + s.mu.Unlock() st, err := newServerTool(t, h) if err != nil { return err From cfa5c1d8028327a8ea088580abb0982cb39e4eb6 Mon Sep 17 00:00:00 2001 From: Tim Gossett Date: Thu, 14 Aug 2025 14:27:25 -0400 Subject: [PATCH 098/221] mcp: check that UnmarshalJSON methods for Content don't panic on nil Introduces a new test file content_nil_test.go which verifies that UnmarshalJSON methods for various Content types do not panic when unmarshaling onto nil pointers. Adds a nil check in contentFromWire function to guard against a nil wire.Content parameter. Tests cover different scenarios, including valid and invalid content types, as well as cases with empty or missing content fields. For #205 --- mcp/content.go | 3 + mcp/content_nil_test.go | 224 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 mcp/content_nil_test.go diff --git a/mcp/content.go b/mcp/content.go index 8bf75f0f..f8777154 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -252,6 +252,9 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e } func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { + if wire == nil { + return nil, fmt.Errorf("content wire is nil") + } if allow != nil && !allow[wire.Type] { return nil, fmt.Errorf("invalid content type %q", wire.Type) } diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go new file mode 100644 index 00000000..c803ba69 --- /dev/null +++ b/mcp/content_nil_test.go @@ -0,0 +1,224 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains tests to verify that UnmarshalJSON methods for Content types +// don't panic when unmarshaling onto nil pointers, as requested in GitHub issue #205. +// +// NOTE: The contentFromWire function has been fixed to handle nil wire.Content +// gracefully by returning an error instead of panicking. + +package mcp_test + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestContentUnmarshalNil(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + want interface{} + }{ + { + name: "CallToolResult nil Content", + json: `{"content":[{"type":"text","text":"hello"}]}`, + content: &mcp.CallToolResult{}, + want: &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + { + name: "CreateMessageResult nil Content", + json: `{"content":{"type":"text","text":"hello"},"model":"test","role":"user"}`, + content: &mcp.CreateMessageResult{}, + want: &mcp.CreateMessageResult{Content: &mcp.TextContent{Text: "hello"}, Model: "test", Role: "user"}, + }, + { + name: "PromptMessage nil Content", + json: `{"content":{"type":"text","text":"hello"},"role":"user"}`, + content: &mcp.PromptMessage{}, + want: &mcp.PromptMessage{Content: &mcp.TextContent{Text: "hello"}, Role: "user"}, + }, + { + name: "SamplingMessage nil Content", + json: `{"content":{"type":"text","text":"hello"},"role":"user"}`, + content: &mcp.SamplingMessage{}, + want: &mcp.SamplingMessage{Content: &mcp.TextContent{Text: "hello"}, Role: "user"}, + }, + { + name: "CallToolResultFor nil Content", + json: `{"content":[{"type":"text","text":"hello"}]}`, + content: &mcp.CallToolResultFor[string]{}, + want: &mcp.CallToolResultFor[string]{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if err != nil { + t.Errorf("UnmarshalJSON failed: %v", err) + } + + // Verify that the Content field was properly populated + if cmp.Diff(tt.want, tt.content) != "" { + t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content)) + } + }) + } +} + +func TestContentUnmarshalNilWithDifferentTypes(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "ImageContent", + json: `{"content":{"type":"image","mimeType":"image/png","data":"YTFiMmMz"}}`, + content: &mcp.CreateMessageResult{}, + expectError: false, + }, + { + name: "AudioContent", + json: `{"content":{"type":"audio","mimeType":"audio/wav","data":"YTFiMmMz"}}`, + content: &mcp.CreateMessageResult{}, + expectError: false, + }, + { + name: "ResourceLink", + json: `{"content":{"type":"resource_link","uri":"file:///test","name":"test"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // CreateMessageResult only allows text, image, audio + }, + { + name: "EmbeddedResource", + json: `{"content":{"type":"resource","resource":{"uri":"file://test","text":"test"}}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // CreateMessageResult only allows text, image, audio + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify that the Content field was properly populated for successful cases + if !tt.expectError { + if result, ok := tt.content.(*mcp.CreateMessageResult); ok { + if result.Content == nil { + t.Error("CreateMessageResult.Content was not populated") + } + } + } + }) + } +} + +func TestContentUnmarshalNilWithEmptyContent(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "Empty Content array", + json: `{"content":[]}`, + content: &mcp.CallToolResult{}, + expectError: false, + }, + { + name: "Missing Content field", + json: `{"model":"test","role":"user"}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // Content field is required for CreateMessageResult + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "Invalid content type", + json: `{"content":{"type":"invalid","text":"hello"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, + }, + { + name: "Missing type field", + json: `{"content":{"text":"hello"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} From 54c99810f6f0c3e0e0b43cfe30f85fcaf18db23f Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 14 Aug 2025 16:15:53 +0000 Subject: [PATCH 099/221] mcp: add a test for streamable accept headers Add a test following up on the fix for #290. --- mcp/streamable_test.go | 245 ++++++++++++++++++++++++----------------- 1 file changed, 143 insertions(+), 102 deletions(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index ca5e5a5c..e9bb54e2 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -345,30 +345,6 @@ func TestStreamableServerTransport(t *testing.T) { // faking the behavior of a streamable client using a sequence of HTTP // requests. - // A step is a single step in the tests below, consisting of a request payload - // and expected response. - type step struct { - // If OnRequest is > 0, this step only executes after a request with the - // given ID is received. - // - // All OnRequest steps must occur before the step that creates the request. - // - // To avoid tests hanging when there's a bug, it's expected that this - // request is received in the course of a *synchronous* request to the - // server (otherwise, we wouldn't be able to terminate the test without - // analyzing a dependency graph). - OnRequest int64 - // If set, Async causes the step to run asynchronously to other steps. - // Redundant with OnRequest: all OnRequest steps are asynchronous. - Async bool - - Method string // HTTP request method - Send []jsonrpc.Message // messages to send - CloseAfter int // if nonzero, close after receiving this many messages - StatusCode int // expected status code - Recv []jsonrpc.Message // expected messages to receive - } - // JSON-RPC message constructors. req := func(id int64, method string, params any) *jsonrpc.Request { r := &jsonrpc.Request{ @@ -399,33 +375,67 @@ func TestStreamableServerTransport(t *testing.T) { ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) initializedMsg := req(0, notificationInitialized, &InitializedParams{}) - initialize := step{ - Method: "POST", - Send: []jsonrpc.Message{initReq}, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{initResp}, + initialize := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, } - initialized := step{ - Method: "POST", - Send: []jsonrpc.Message{initializedMsg}, - StatusCode: http.StatusAccepted, + initialized := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initializedMsg}, + wantStatusCode: http.StatusAccepted, } tests := []struct { name string tool func(*testing.T, context.Context, *ServerSession) - steps []step + steps []streamableRequest }{ { name: "basic", - steps: []step{ + steps: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + }, + }, + }, + { + name: "accept headers", + steps: []streamableRequest{ initialize, initialized, + // Test various accept headers. + { + method: "POST", + accept: []string{"text/plain", "application/*"}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // missing text/event-stream + }, + { + method: "POST", + accept: []string{"text/event-stream"}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // missing application/json + }, + { + method: "POST", + accept: []string{"text/plain", "*/*"}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + }, { - Method: "POST", - Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + method: "POST", + accept: []string{"text/*, application/*"}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, }, }, }, @@ -437,16 +447,16 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []step{ + steps: []streamableRequest{ initialize, initialized, { - Method: "POST", - Send: []jsonrpc.Message{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, @@ -461,24 +471,24 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Call failed: %v", err) } }, - steps: []step{ + steps: []streamableRequest{ initialize, initialized, { - Method: "POST", - OnRequest: 1, - Send: []jsonrpc.Message{ + method: "POST", + onRequest: 1, + messages: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, - StatusCode: http.StatusAccepted, + wantStatusCode: http.StatusAccepted, }, { - Method: "POST", - Send: []jsonrpc.Message{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, @@ -502,34 +512,34 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []step{ + steps: []streamableRequest{ initialize, initialized, { - Method: "POST", - OnRequest: 1, - Send: []jsonrpc.Message{ + method: "POST", + onRequest: 1, + messages: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, - StatusCode: http.StatusAccepted, + wantStatusCode: http.StatusAccepted, }, { - Method: "GET", - Async: true, - StatusCode: http.StatusOK, - CloseAfter: 2, - Recv: []jsonrpc.Message{ + method: "GET", + async: true, + wantStatusCode: http.StatusOK, + closeAfter: 2, + wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, }, { - Method: "POST", - Send: []jsonrpc.Message{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, }, @@ -537,30 +547,30 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "errors", - steps: []step{ + steps: []streamableRequest{ { - Method: "PUT", - StatusCode: http.StatusMethodNotAllowed, + method: "PUT", + wantStatusCode: http.StatusMethodNotAllowed, }, { - Method: "DELETE", - StatusCode: http.StatusBadRequest, + method: "DELETE", + wantStatusCode: http.StatusBadRequest, }, { - Method: "POST", - Send: []jsonrpc.Message{req(1, "notamethod", nil)}, - StatusCode: http.StatusBadRequest, // notamethod is an invalid method + method: "POST", + messages: []jsonrpc.Message{req(1, "notamethod", nil)}, + wantStatusCode: http.StatusBadRequest, // notamethod is an invalid method }, { - Method: "POST", - Send: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusBadRequest, // tools/call must have an ID + method: "POST", + messages: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // tools/call must have an ID }, { - Method: "POST", - Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusOK, - Recv: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, }, @@ -594,8 +604,8 @@ func TestStreamableServerTransport(t *testing.T) { var mu sync.Mutex blocks := make(map[int64]chan struct{}) for _, step := range test.steps { - if step.OnRequest > 0 { - blocks[step.OnRequest] = make(chan struct{}) + if step.onRequest > 0 { + blocks[step.onRequest] = make(chan struct{}) } } @@ -609,16 +619,16 @@ func TestStreamableServerTransport(t *testing.T) { sessionID.Store("") // doStep executes a single step. - doStep := func(t *testing.T, step step) { - if step.OnRequest > 0 { + doStep := func(t *testing.T, step streamableRequest) { + if step.onRequest > 0 { // Block the step until we've received the server->client request. mu.Lock() - block := blocks[step.OnRequest] + block := blocks[step.onRequest] mu.Unlock() select { case <-block: case <-syncRequestsDone: - t.Errorf("after all sync requests are complete, request still blocked on %d", step.OnRequest) + t.Errorf("after all sync requests are complete, request still blocked on %d", step.onRequest) return } } @@ -650,14 +660,13 @@ func TestStreamableServerTransport(t *testing.T) { mu.Unlock() } got = append(got, m) - if step.CloseAfter > 0 && len(got) == step.CloseAfter { + if step.closeAfter > 0 && len(got) == step.closeAfter { cancel() } } }() - gotSessionID, gotStatusCode, err := streamingRequest(ctx, - httpServer.URL, sessionID.Load().(string), step.Method, step.Send, out) + gotSessionID, gotStatusCode, err := step.do(ctx, httpServer.URL, sessionID.Load().(string), out) // Don't fail on cancelled requests: error (if any) is handled // elsewhere. @@ -665,13 +674,13 @@ func TestStreamableServerTransport(t *testing.T) { t.Fatal(err) } - if gotStatusCode != step.StatusCode { - t.Errorf("got status %d, want %d", gotStatusCode, step.StatusCode) + if gotStatusCode != step.wantStatusCode { + t.Errorf("got status %d, want %d", gotStatusCode, step.wantStatusCode) } wg.Wait() transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) - if diff := cmp.Diff(step.Recv, got, transform); diff != "" { + if diff := cmp.Diff(step.wantMessages, got, transform); diff != "" { t.Errorf("received unexpected messages (-want +got):\n%s", diff) } sessionID.CompareAndSwap("", gotSessionID) @@ -679,7 +688,7 @@ func TestStreamableServerTransport(t *testing.T) { var wg sync.WaitGroup for _, step := range test.steps { - if step.Async || step.OnRequest > 0 { + if step.async || step.onRequest > 0 { wg.Add(1) go func() { defer wg.Done() @@ -699,6 +708,33 @@ func TestStreamableServerTransport(t *testing.T) { } } +// A streamableRequest describes a single streamable HTTP request, consisting +// of a request payload and expected response. +type streamableRequest struct { + // If onRequest is > 0, this step only executes after a request with the + // given ID is received. + // + // All onRequest steps must occur before the step that creates the request. + // + // To avoid tests hanging when there's a bug, it's expected that this + // request is received in the course of a *synchronous* request to the + // server (otherwise, we wouldn't be able to terminate the test without + // analyzing a dependency graph). + onRequest int64 + // If set, async causes the step to run asynchronously to other steps. + // Redundant with OnRequest: all OnRequest steps are asynchronous. + async bool + + // Request attributes + method string // HTTP request method + accept []string // if non-empty, the Accept header to use; otherwise the default header is used + messages []jsonrpc.Message // messages to send + + closeAfter int // if nonzero, close after receiving this many messages + wantStatusCode int // expected status code + wantMessages []jsonrpc.Message // expected messages to receive +} + // streamingRequest makes a request to the given streamable server with the // given url, sessionID, and method. // @@ -712,19 +748,19 @@ func TestStreamableServerTransport(t *testing.T) { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []jsonrpc.Message, out chan<- jsonrpc.Message) (string, int, error) { +func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, error) { defer close(out) var body []byte - if len(in) == 1 { - data, err := jsonrpc2.EncodeMessage(in[0]) + if len(s.messages) == 1 { + data, err := jsonrpc2.EncodeMessage(s.messages[0]) if err != nil { return "", 0, fmt.Errorf("encoding message: %w", err) } body = data } else { var rawMsgs []json.RawMessage - for _, msg := range in { + for _, msg := range s.messages { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return "", 0, fmt.Errorf("encoding message: %w", err) @@ -738,7 +774,7 @@ func streamingRequest(ctx context.Context, serverURL, sessionID, method string, body = data } - req, err := http.NewRequestWithContext(ctx, method, serverURL, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, s.method, serverURL, bytes.NewReader(body)) if err != nil { return "", 0, fmt.Errorf("creating request: %w", err) } @@ -746,8 +782,13 @@ func streamingRequest(ctx context.Context, serverURL, sessionID, method string, req.Header.Set("Mcp-Session-Id", sessionID) } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept", "text/plain") // ensure multiple accept headers are allowed - req.Header.Add("Accept", "application/json, text/event-stream") + if len(s.accept) > 0 { + for _, accept := range s.accept { + req.Header.Add("Accept", accept) + } + } else { + req.Header.Add("Accept", "application/json, text/event-stream") + } resp, err := http.DefaultClient.Do(req) if err != nil { From 0a8fe40c516853f411ac7605317cf85b2d23a98e Mon Sep 17 00:00:00 2001 From: Hugh Palmer Date: Fri, 8 Aug 2025 11:51:53 +0200 Subject: [PATCH 100/221] mcp: changed streamID's from int64 to random strings (#266) - Changed StreamID's to store randomly generated strings. - Updated all tests. - Resolved conflicts --- mcp/event.go | 2 +- mcp/event_test.go | 70 +++++++++++++++++++++--------------------- mcp/streamable.go | 37 ++++++++++------------ mcp/streamable_test.go | 18 +++++------ 4 files changed, 61 insertions(+), 66 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 9092da76..f4f4eeea 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -392,7 +392,7 @@ func (s *MemoryEventStore) debugString() string { fmt.Fprintf(&b, "; ") } dl := sm[sid] - fmt.Fprintf(&b, "%s %d first=%d", sess, sid, dl.first) + fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first) for _, d := range dl.data { fmt.Fprintf(&b, " %s", d) } diff --git a/mcp/event_test.go b/mcp/event_test.go index 147a947a..601e8300 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -119,10 +119,10 @@ func TestMemoryEventStoreState(t *testing.T) { { "appends", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") }, "S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4", 8, @@ -130,10 +130,10 @@ func TestMemoryEventStoreState(t *testing.T) { { "session close", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") s.SessionClosed(ctx, "S1") }, "S2 8 first=0 d4", @@ -142,10 +142,10 @@ func TestMemoryEventStoreState(t *testing.T) { { "purge", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") // We are using 8 bytes (d1,d2, d3, d4). // To purge 6, we remove the first of each stream, leaving only d3. s.SetMaxBytes(2) @@ -157,15 +157,15 @@ func TestMemoryEventStoreState(t *testing.T) { { "purge append", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") s.SetMaxBytes(2) // Up to here, identical to the "purge" case. // Each of these additions will result in a purge. - appendEvent(s, "S1", 2, "d5") // remove d3 - appendEvent(s, "S1", 2, "d6") // remove d5 + appendEvent(s, "S1", "2", "d5") // remove d3 + appendEvent(s, "S1", "2", "d6") // remove d5 }, "S1 1 first=2; S1 2 first=2 d6; S2 8 first=1", 2, @@ -173,15 +173,15 @@ func TestMemoryEventStoreState(t *testing.T) { { "purge resize append", func(s *MemoryEventStore) { - appendEvent(s, "S1", 1, "d1") - appendEvent(s, "S1", 2, "d2") - appendEvent(s, "S1", 1, "d3") - appendEvent(s, "S2", 8, "d4") + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") s.SetMaxBytes(2) // Up to here, identical to the "purge" case. s.SetMaxBytes(6) // make room - appendEvent(s, "S1", 2, "d5") - appendEvent(s, "S1", 2, "d6") + appendEvent(s, "S1", "2", "d5") + appendEvent(s, "S1", "2", "d6") }, // The other streams remain, because we may add to them. "S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1", @@ -206,10 +206,10 @@ func TestMemoryEventStoreAfter(t *testing.T) { ctx := context.Background() s := NewMemoryEventStore(nil) s.SetMaxBytes(4) - s.Append(ctx, "S1", 1, []byte("d1")) - s.Append(ctx, "S1", 1, []byte("d2")) - s.Append(ctx, "S1", 1, []byte("d3")) - s.Append(ctx, "S1", 2, []byte("d4")) // will purge d1 + s.Append(ctx, "S1", "1", []byte("d1")) + s.Append(ctx, "S1", "1", []byte("d2")) + s.Append(ctx, "S1", "1", []byte("d3")) + s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1 want := "S1 1 first=1 d2 d3; S1 2 first=0 d4" if got := s.debugString(); got != want { t.Fatalf("got state %q, want %q", got, want) @@ -222,14 +222,14 @@ func TestMemoryEventStoreAfter(t *testing.T) { want []string wantErr string // if non-empty, error should contain this string }{ - {"S1", 1, 0, []string{"d2", "d3"}, ""}, - {"S1", 1, 1, []string{"d3"}, ""}, - {"S1", 1, 2, nil, ""}, - {"S1", 2, 0, nil, ""}, - {"S1", 3, 0, nil, "unknown stream ID"}, - {"S2", 0, 0, nil, "unknown session ID"}, + {"S1", "1", 0, []string{"d2", "d3"}, ""}, + {"S1", "1", 1, []string{"d3"}, ""}, + {"S1", "1", 2, nil, ""}, + {"S1", "2", 0, nil, ""}, + {"S1", "3", 0, nil, "unknown stream ID"}, + {"S2", "0", 0, nil, "unknown session ID"}, } { - t.Run(fmt.Sprintf("%s-%d-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) { + t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) { var got []string for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) { if err != nil { diff --git a/mcp/streamable.go b/mcp/streamable.go index 5692b985..7f5ce21b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -276,7 +276,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) // // It is always text/event-stream, since it must carry arbitrarily many // messages. - t.connection.streams[0] = newStream(0, false) + t.connection.streams[""] = newStream("", false) if t.connection.eventStore == nil { t.connection.eventStore = NewMemoryEventStore(nil) } @@ -334,7 +334,7 @@ func (c *streamableServerConn) SessionID() string { // at any time. type stream struct { // id is the logical ID for the stream, unique within a session. - // ID 0 is used for messages that don't correlate with an incoming request. + // an empty string is used for messages that don't correlate with an incoming request. id StreamID // jsonResponse records whether this stream should respond with application/json @@ -382,9 +382,9 @@ func signalChanPtr() *chan struct{} { return &c } -// A StreamID identifies a stream of SSE events. It is unique within the stream's +// A StreamID identifies a stream of SSE events. It is globally unique. // [ServerSession]. -type StreamID int64 +type StreamID string // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in @@ -434,7 +434,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // It returns an HTTP status code and error message. func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id := StreamID(0) + id := StreamID("") // By default, we haven't seen a last index. Since indices start at 0, we represent // that by -1. This is incremented just before each event is written, in streamResponse // around L407. @@ -462,7 +462,7 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request return } defer stream.signal.Store(nil) - persistent := id == 0 // Only the special stream 0 is a hanging get. + persistent := id == "" // Only the special stream "" is a hanging get. c.respondSSE(stream, w, req, lastIdx, persistent) } @@ -520,7 +520,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream = newStream(StreamID(c.lastStreamID.Add(1)), c.jsonResponse) + stream = newStream(StreamID(randText()), c.jsonResponse) c.mu.Lock() c.streams[stream.id] = stream stream.requests = requests @@ -719,7 +719,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per // // See also [parseEventID]. func formatEventID(sid StreamID, idx int) string { - return fmt.Sprintf("%d_%d", sid, idx) + return fmt.Sprintf("%s_%d", sid, idx) } // parseEventID parses a Last-Event-ID value into a logical stream id and @@ -729,15 +729,12 @@ func formatEventID(sid StreamID, idx int) string { func parseEventID(eventID string) (sid StreamID, idx int, ok bool) { parts := strings.Split(eventID, "_") if len(parts) != 2 { - return 0, 0, false + return "", 0, false } - stream, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil || stream < 0 { - return 0, 0, false - } - idx, err = strconv.Atoi(parts[1]) + stream := StreamID(parts[0]) + idx, err := strconv.Atoi(parts[1]) if err != nil || idx < 0 { - return 0, 0, false + return "", 0, false } return StreamID(stream), idx, true } @@ -778,7 +775,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e // Find the logical connection corresponding to this request. // // For messages sent outside of a request context, this is the default - // connection 0. + // connection "". var forStream StreamID if forRequest.IsValid() { c.mu.Lock() @@ -799,7 +796,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e stream := c.streams[forStream] if stream == nil { - return fmt.Errorf("no stream with ID %d", forStream) + return fmt.Errorf("no stream with ID %s", forStream) } // Special case a few conditions where we fall back on stream 0 (the hanging GET): @@ -809,11 +806,11 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e // // TODO(rfindley): either of these, particularly the first, might be // considered a bug in the server. Report it through a side-channel? - if len(stream.requests) == 0 && forStream != 0 || stream.jsonResponse && !isResponse { - stream = c.streams[0] + if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse { + stream = c.streams[""] } - // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == 0 + // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == "" // and the client never did a GET), then memory will grow without bound. Consider a mitigation. stream.outgoing = append(stream.outgoing, data) if isResponse { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e9bb54e2..55aadb6a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -888,22 +888,23 @@ func TestEventID(t *testing.T) { sid StreamID idx int }{ - {0, 0}, - {0, 1}, - {1, 0}, - {1, 1}, - {1234, 5678}, + {"0", 0}, + {"0", 1}, + {"1", 0}, + {"1", 1}, + {"", 1}, + {"1234", 5678}, } for _, test := range tests { - t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) { + t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) { eventID := formatEventID(test.sid, test.idx) gotSID, gotIdx, ok := parseEventID(eventID) if !ok { t.Fatalf("parseEventID(%q) failed, want ok", eventID) } if gotSID != test.sid || gotIdx != test.idx { - t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx) + t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx) } }) } @@ -912,10 +913,7 @@ func TestEventID(t *testing.T) { "", "_", "1_", - "_1", - "a_1", "1_a", - "-1_1", "1_-1", } From 1afdb1f27309c888be8a4bbb91deb584ede164be Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 13 Aug 2025 21:13:00 +0000 Subject: [PATCH 101/221] mcp: several fixes for streamable client reconnection Do a pass through the streamable client reconnection logic, and fix several bugs. - Establish the initial GET even if MaxRetries is 0 (#256). This was broken because the GET bypasses the initial request and going straight to the SSE GET reconnection logic. - Release the stream ownership when POST requests exit. - Don't reconnect POST requests if we've received the expected response. - Move unexported reconnection config to constants. Otherwise it is too hard to set ReconnectOptions (you have to use the DefaultOptions and mutate). Fixes #256 --- mcp/streamable.go | 122 ++++++++++++++++++++++++--------------- mcp/streamable_test.go | 126 +++++++++++++++++++++++++++-------------- 2 files changed, 160 insertions(+), 88 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 7f5ce21b..e3d80bc3 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -329,6 +329,7 @@ func (c *streamableServerConn) SessionID() string { // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header. +// // A stream ends only when its session ends; we cannot determine its end otherwise, // since a client may send a GET with a Last-Event-ID that references the stream // at any time. @@ -529,6 +530,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } c.mu.Unlock() stream.signal.Store(signalChanPtr()) + defer stream.signal.Store(nil) } // Publish incoming messages. @@ -857,27 +859,27 @@ type StreamableReconnectOptions struct { // MaxRetries is the maximum number of times to attempt a reconnect before giving up. // A value of 0 or less means never retry. MaxRetries int - - // growFactor is the multiplicative factor by which the delay increases after each attempt. - // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. - // It must be 1.0 or greater if MaxRetries is greater than 0. - growFactor float64 - - // initialDelay is the base delay for the first reconnect attempt. - initialDelay time.Duration - - // maxDelay caps the backoff delay, preventing it from growing indefinitely. - maxDelay time.Duration } // DefaultReconnectOptions provides sensible defaults for reconnect logic. var DefaultReconnectOptions = &StreamableReconnectOptions{ - MaxRetries: 5, - growFactor: 1.5, - initialDelay: 1 * time.Second, - maxDelay: 30 * time.Second, + MaxRetries: 5, } +// These settings are not (yet) exposed to the user in +// StreamableReconnectOptions. Since they're invisible, keep them const rather +// than requiring the user to start from DefaultReconnectOptions and mutate. +const ( + // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. + // It must be 1.0 or greater if MaxRetries is greater than 0. + reconnectGrowFactor = 1.5 + // reconnectInitialDelay is the base delay for the first reconnect attempt. + reconnectInitialDelay = 1 * time.Second + // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. + reconnectMaxDelay = 30 * time.Second +) + // StreamableClientTransportOptions provides options for the // [NewStreamableClientTransport] constructor. // @@ -928,7 +930,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er conn := &streamableClientConn{ url: t.Endpoint, client: client, - incoming: make(chan []byte, 100), + incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), ReconnectOptions: reconnOpts, ctx: connCtx, @@ -944,7 +946,7 @@ type streamableClientConn struct { client *http.Client ctx context.Context cancel context.CancelFunc - incoming chan []byte + incoming chan jsonrpc.Message // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -988,7 +990,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // § 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE(nil, true) + go c.handleSSE(nil, true, nil) } // fail handles an asynchronous error while reading. @@ -1031,8 +1033,8 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error return nil, c.failure() case <-c.done: return nil, io.EOF - case data := <-c.incoming: - return jsonrpc2.DecodeMessage(data) + case msg := <-c.incoming: + return msg, nil } } @@ -1042,7 +1044,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } - data, err := jsonrpc2.EncodeMessage(msg) + data, err := jsonrpc.EncodeMessage(msg) if err != nil { return err } @@ -1088,7 +1090,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e go c.handleJSON(resp) case "text/event-stream": - go c.handleSSE(resp, false) + jsonReq, _ := msg.(*jsonrpc.Request) + go c.handleSSE(resp, false, jsonReq) default: resp.Body.Close() @@ -1116,8 +1119,13 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { c.fail(err) return } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + c.fail(fmt.Errorf("failed to decode response: %v", err)) + return + } select { - case c.incoming <- body: + case c.incoming <- msg: case <-c.done: // The connection was closed by the client; exit gracefully. } @@ -1125,21 +1133,26 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { // handleSSE manages the lifecycle of an SSE connection. It can be either // persistent (for the main GET listener) or temporary (for a POST response). -func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { +// +// If forReq is set, it is the request that initiated the stream, and the +// stream is complete when we receive its response. +func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { resp := initialResp var lastEventID string for { - eventID, clientClosed := c.processStream(resp) - lastEventID = eventID + if resp != nil { + eventID, clientClosed := c.processStream(resp, forReq) + lastEventID = eventID - // If the connection was closed by the client, we're done. - if clientClosed { - return - } - // If the stream has ended, then do not reconnect if the stream is - // temporary (POST initiated SSE). - if lastEventID == "" && !persistent { - return + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && !persistent { + return + } } // The stream was interrupted or ended by the server. Attempt to reconnect. @@ -1159,12 +1172,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { - if resp == nil { - // TODO(rfindley): avoid this special handling. - return "", false - } - +func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { @@ -1175,8 +1183,21 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s lastEventID = evt.ID } + msg, err := jsonrpc.DecodeMessage(evt.Data) + if err != nil { + c.fail(fmt.Errorf("failed to decode event: %v", err)) + return "", true + } + select { - case c.incoming <- evt.Data: + case c.incoming <- msg: + if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil { + // TODO: we should never get a response when forReq is nil (the hanging GET). + // We should detect this case, and eliminate the 'persistent' flag arguments. + if jsonResp.ID == forReq.ID { + return "", true + } + } case <-c.done: // The connection was closed by the client; exit gracefully. return "", true @@ -1192,11 +1213,20 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { var finalErr error - for attempt := 0; attempt < c.ReconnectOptions.MaxRetries; attempt++ { + // We can reach the 'reconnect' path through the hanging GET, in which case + // lastEventID will be "". + // + // In this case, we need an initial attempt. + attempt := 0 + if lastEventID != "" { + attempt = 1 + } + + for ; attempt <= c.ReconnectOptions.MaxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") - case <-time.After(calculateReconnectDelay(c.ReconnectOptions, attempt)): + case <-time.After(calculateReconnectDelay(attempt)): resp, err := c.establishSSE(lastEventID) if err != nil { finalErr = err // Store the error and try again. @@ -1267,11 +1297,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, } // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. -func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration { +func calculateReconnectDelay(attempt int) time.Duration { // Calculate the exponential backoff using the grow factor. - backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))) + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt))) // Cap the backoffDuration at maxDelay. - backoffDuration = min(backoffDuration, opts.maxDelay) + backoffDuration = min(backoffDuration, reconnectMaxDelay) // Use a full jitter using backoffDuration jitter := rand.N(backoffDuration) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 55aadb6a..11600fbc 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -16,6 +16,7 @@ import ( "net/http/httptest" "net/http/httputil" "net/url" + "sort" "strings" "sync" "sync/atomic" @@ -186,12 +187,30 @@ func TestStreamableTransports(t *testing.T) { } } -// TestClientReplay verifies that the client can recover from a -// mid-stream network failure and receive replayed messages. It uses a proxy -// that is killed and restarted to simulate a recoverable network outage. +// TestClientReplay verifies that the client can recover from a mid-stream +// network failure and receive replayed messages (if replay is configured). It +// uses a proxy that is killed and restarted to simulate a recoverable network +// outage. func TestClientReplay(t *testing.T) { + for _, test := range []clientReplayTest{ + {"default", nil, true}, + {"no retries", &StreamableReconnectOptions{}, false}, + } { + t.Run(test.name, func(t *testing.T) { + testClientReplay(t, test) + }) + } +} + +type clientReplayTest struct { + name string + options *StreamableReconnectOptions + wantRecovered bool +} + +func testClientReplay(t *testing.T, test clientReplayTest) { notifications := make(chan string) - // 1. Configure the real MCP server. + // Configure the real MCP server. server := NewServer(testImpl, nil) // Use a channel to synchronize the server's message sending with the test's @@ -200,23 +219,24 @@ func TestClientReplay(t *testing.T) { serverClosed := make(chan struct{}) server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { - go func() { - bgCtx := context.Background() - // Send the first two messages immediately. - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) - - // Signal the test that it can now kill the proxy. - close(serverReadyToKillProxy) - <-serverClosed - - // These messages should be queued for replay by the server after - // the client's connection drops. - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) - }() - return &CallToolResult{}, nil + // Send one message to the request context, and another to a background + // context (which will end up on the hanging GET). + + bgCtx := context.Background() + req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg1"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + + // Signal the test that it can now kill the proxy. + close(serverReadyToKillProxy) + <-serverClosed + + // These messages should be queued for replay by the server after + // the client's connection drops. + req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg3"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + return new(CallToolResult), nil }) + realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) defer realServer.Close() realServerURL, err := url.Parse(realServer.URL) @@ -224,12 +244,12 @@ func TestClientReplay(t *testing.T) { t.Fatalf("Failed to parse real server URL: %v", err) } - // 2. Configure a proxy that sits between the client and the real server. + // Configure a proxy that sits between the client and the real server. proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL) proxy := httptest.NewServer(proxyHandler) proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. - // 3. Configure the client to connect to the proxy with default options. + // Configure the client to connect to the proxy with default options. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ @@ -237,20 +257,24 @@ func TestClientReplay(t *testing.T) { notifications <- req.Params.Message }, }) - clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: proxy.URL}, nil) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: proxy.URL, + ReconnectOptions: test.options, + }, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } defer clientSession.Close() - clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) - - // 4. Read and verify messages until the server signals it's ready for the proxy kill. - receivedNotifications := readNotifications(t, ctx, notifications, 2) - wantReceived := []string{"msg1", "msg2"} - if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { - t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) - } + var ( + wg sync.WaitGroup + callErr error + ) + wg.Add(1) + go func() { + defer wg.Done() + _, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) + }() select { case <-serverReadyToKillProxy: @@ -259,32 +283,50 @@ func TestClientReplay(t *testing.T) { t.Fatalf("Context timed out before server was ready to kill proxy") } - // 5. Simulate a total network failure by closing the proxy. + // We should always get the first two notifications. + msgs := readNotifications(t, ctx, notifications, 2) + sort.Strings(msgs) // notifications may arrive in either order + want := []string{"msg1", "msg2"} + if diff := cmp.Diff(want, msgs); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } + + // Simulate a total network failure by closing the proxy. t.Log("--- Killing proxy to simulate network failure ---") proxy.CloseClientConnections() proxy.Close() close(serverClosed) - // 6. Simulate network recovery by restarting the proxy on the same address. + // Simulate network recovery by restarting the proxy on the same address. t.Logf("--- Restarting proxy on %s ---", proxyAddr) listener, err := net.Listen("tcp", proxyAddr) if err != nil { t.Fatalf("Failed to listen on proxy address: %v", err) } + restartedProxy := &http.Server{Handler: proxyHandler} go restartedProxy.Serve(listener) defer restartedProxy.Close() - // 7. Continue reading from the same connection object. - // Its internal logic should successfully retry, reconnect to the new proxy, - // and receive the replayed messages. - recoveredNotifications := readNotifications(t, ctx, notifications, 2) + wg.Wait() - // 8. Verify the correct messages were received on the recovered connection. - wantRecovered := []string{"msg3", "msg4"} - - if diff := cmp.Diff(wantRecovered, recoveredNotifications); diff != "" { - t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + if test.wantRecovered { + // If we've recovered, we should get all 4 notifications and the tool call + // should have succeeded. + msgs := readNotifications(t, ctx, notifications, 2) + sort.Strings(msgs) + want := []string{"msg3", "msg4"} + if diff := cmp.Diff(want, msgs); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } + if callErr != nil { + t.Errorf("CallTool failed unexpectedly: %v", err) + } + } else { + // Otherwise, the call should fail. + if callErr == nil { + t.Errorf("CallTool succeeded unexpectedly") + } } } From eb5eb06341f74004a36af4026a463489581a9628 Mon Sep 17 00:00:00 2001 From: Manuel Martinez Date: Mon, 18 Aug 2025 09:11:25 -0700 Subject: [PATCH 102/221] ci: add go vet to lint step (#316) - Adds `go vet` call to the lint step in GA workflow. - Enforces a compatible `go-version` with the minimum in the matrix for the `test` step. - Fixes some minor `.yml` formatting Fixes #281 --- There's some discussion in #281 about adding [staticcheck](https://staticcheck.dev/docs/) as well but when I ran it, it found some hits (see below) that are presumably not severe, and I wouldn't want to block deployments before getting some feedback from maintainers. ``` internal/jsonrpc2/conn.go:695:4: this value of writeErr is never used (SA4006) internal/jsonrpc2/conn.go:697:5: this value of err is never used (SA4006) internal/jsonrpc2/conn.go:700:4: this value of err is never used (SA4006) mcp/mcp_test.go:501:2: var resource3 is unused (U1000) mcp/mcp_test.go:509:5: var embeddedResources is unused (U1000) mcp/mcp_test.go:513:6: func handleEmbeddedResource is unused (U1000) mcp/streamable.go:124:3: unnecessary assignment to the blank identifier (S1005) mcp/streamable.go:291:2: field lastStreamID is unused (U1000) mcp/streamable.go:293:2: field opts is unused (U1000) mcp/streamable_test.go:542:29: argument ctx is overwritten before first use (SA4009) ``` --- .github/workflows/test.yml | 66 ++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 54b36331..5eb2dacd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,7 +3,7 @@ on: # Manual trigger workflow_dispatch: push: - branches: main + branches: [main] pull_request: permissions: @@ -13,43 +13,47 @@ jobs: lint: runs-on: ubuntu-latest steps: - - name: Check out code - uses: actions/checkout@v4 - - name: Set up Go - uses: actions/setup-go@v5 - - name: Check formatting - run: | - unformatted=$(gofmt -l .) - if [ -n "$unformatted" ]; then - echo "The following files are not properly formatted:" - echo "$unformatted" - exit 1 - fi - echo "All Go files are properly formatted" + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "^1.23" + - name: Check formatting + run: | + unformatted=$(gofmt -l .) + if [ -n "$unformatted" ]; then + echo "The following files are not properly formatted:" + echo "$unformatted" + exit 1 + fi + echo "All Go files are properly formatted" + - name: Run Go vet + run: go vet ./... test: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.23', '1.24', '1.25.0-rc.3' ] + go: ["1.23", "1.24", "1.25.0-rc.3"] steps: - - name: Check out code - uses: actions/checkout@v4 - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version: ${{ matrix.go }} - - name: Test - run: go test -v ./... + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + - name: Test + run: go test -v ./... race-test: runs-on: ubuntu-latest steps: - - name: Check out code - uses: actions/checkout@v4 - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version: '1.24' - - name: Test with -race - run: go test -v -race ./... + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.24" + - name: Test with -race + run: go test -v -race ./... From 767dacb820454d1aeef21a514db2f03fb05543a5 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:53:34 -0400 Subject: [PATCH 103/221] mcp/streamable: remove StreamableReconnectOptions (#319) Moves MaxRetries into the parent struct and assume a negative integer signifies to not retry. Fixes #308 --- mcp/streamable.go | 69 +++++++++++++++++++----------------------- mcp/streamable_test.go | 10 +++--- 2 files changed, 36 insertions(+), 43 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index e3d80bc3..9ae20c02 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -849,26 +849,15 @@ func (c *streamableServerConn) Close() error { // endpoint serving the streamable HTTP transport defined by the 2025-03-26 // version of the spec. type StreamableClientTransport struct { - Endpoint string - HTTPClient *http.Client - ReconnectOptions *StreamableReconnectOptions -} - -// StreamableReconnectOptions defines parameters for client reconnect attempts. -type StreamableReconnectOptions struct { + Endpoint string + HTTPClient *http.Client // MaxRetries is the maximum number of times to attempt a reconnect before giving up. - // A value of 0 or less means never retry. + // It defaults to 5. To disable retries, use a negative number. MaxRetries int } -// DefaultReconnectOptions provides sensible defaults for reconnect logic. -var DefaultReconnectOptions = &StreamableReconnectOptions{ - MaxRetries: 5, -} - // These settings are not (yet) exposed to the user in -// StreamableReconnectOptions. Since they're invisible, keep them const rather -// than requiring the user to start from DefaultReconnectOptions and mutate. +// StreamableClientTransport. const ( // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. @@ -887,8 +876,10 @@ const ( type StreamableClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. - HTTPClient *http.Client - ReconnectOptions *StreamableReconnectOptions + HTTPClient *http.Client + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // It defaults to 5. To disable retries, use a negative number. + MaxRetries int } // NewStreamableClientTransport returns a new client transport that connects to @@ -901,7 +892,7 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt t := &StreamableClientTransport{Endpoint: url} if opts != nil { t.HTTPClient = opts.HTTPClient - t.ReconnectOptions = opts.ReconnectOptions + t.MaxRetries = opts.MaxRetries } return t } @@ -919,34 +910,36 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er if client == nil { client = http.DefaultClient } - reconnOpts := t.ReconnectOptions - if reconnOpts == nil { - reconnOpts = DefaultReconnectOptions + maxRetries := t.MaxRetries + if maxRetries == 0 { + maxRetries = 5 + } else if maxRetries < 0 { + maxRetries = 0 } // Create a new cancellable context that will manage the connection's lifecycle. // This is crucial for cleanly shutting down the background SSE listener by // cancelling its blocking network operations, which prevents hangs on exit. connCtx, cancel := context.WithCancel(context.Background()) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - ReconnectOptions: reconnOpts, - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), } return conn, nil } type streamableClientConn struct { - url string - ReconnectOptions *StreamableReconnectOptions - client *http.Client - ctx context.Context - cancel context.CancelFunc - incoming chan jsonrpc.Message + url string + client *http.Client + ctx context.Context + cancel context.CancelFunc + incoming chan jsonrpc.Message + maxRetries int // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -1222,7 +1215,7 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er attempt = 1 } - for ; attempt <= c.ReconnectOptions.MaxRetries; attempt++ { + for ; attempt <= c.maxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") @@ -1244,9 +1237,9 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er } // If the loop completes, all retries have failed. if finalErr != nil { - return nil, fmt.Errorf("connection failed after %d attempts: %w", c.ReconnectOptions.MaxRetries, finalErr) + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) } - return nil, fmt.Errorf("connection failed after %d attempts", c.ReconnectOptions.MaxRetries) + return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) } // isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 11600fbc..9c7f5f9c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -193,8 +193,8 @@ func TestStreamableTransports(t *testing.T) { // outage. func TestClientReplay(t *testing.T) { for _, test := range []clientReplayTest{ - {"default", nil, true}, - {"no retries", &StreamableReconnectOptions{}, false}, + {"default", 0, true}, + {"no retries", -1, false}, } { t.Run(test.name, func(t *testing.T) { testClientReplay(t, test) @@ -204,7 +204,7 @@ func TestClientReplay(t *testing.T) { type clientReplayTest struct { name string - options *StreamableReconnectOptions + maxRetries int wantRecovered bool } @@ -258,8 +258,8 @@ func testClientReplay(t *testing.T, test clientReplayTest) { }, }) clientSession, err := client.Connect(ctx, &StreamableClientTransport{ - Endpoint: proxy.URL, - ReconnectOptions: test.options, + Endpoint: proxy.URL, + MaxRetries: test.maxRetries, }, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) From 75f99995901485751d7b017ddb09abd2e9ab2a1d Mon Sep 17 00:00:00 2001 From: Matt L Date: Mon, 18 Aug 2025 14:49:02 -0400 Subject: [PATCH 104/221] feat(mcp): export initialized result in client session (#320) Export ServerCapabilities and related subtypes and update protocol, server, and tests to use exported types. Add InitializeResult method on ClientSession for accessing the initialize result from sessions. Fixes: https://github.com/modelcontextprotocol/go-sdk/issues/166 Co-authored-by: Shashank Pachava --- mcp/client.go | 6 ++++ mcp/protocol.go | 24 ++++++++-------- mcp/server.go | 14 ++++----- mcp/server_test.go | 64 +++++++++++++++++++++--------------------- mcp/streamable_test.go | 13 +++++---- 5 files changed, 64 insertions(+), 57 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 88eea7da..65a7a954 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -177,6 +177,12 @@ type clientSessionState struct { InitializeResult *InitializeResult } +func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult } + +func (cs *ClientSession) setConn(c Connection) { + cs.mcpConn = c +} + func (cs *ClientSession) ID() string { if c, ok := cs.mcpConn.(hasSessionID); ok { return c.SessionID() diff --git a/mcp/protocol.go b/mcp/protocol.go index d2d343b8..7fbfccf0 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -338,7 +338,7 @@ type InitializeResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` - Capabilities *serverCapabilities `json:"capabilities"` + Capabilities *ServerCapabilities `json:"capabilities"` // Instructions describing how to use the server and its features. // // This can be used by clients to improve the LLM's understanding of available @@ -971,19 +971,19 @@ type Implementation struct { } // Present if the server supports argument autocompletion suggestions. -type completionCapabilities struct{} +type CompletionCapabilities struct{} // Present if the server supports sending log messages to the client. -type loggingCapabilities struct{} +type LoggingCapabilities struct{} // Present if the server offers any prompt templates. -type promptCapabilities struct { +type PromptCapabilities struct { // Whether this server supports notifications for changes to the prompt list. ListChanged bool `json:"listChanged,omitempty"` } // Present if the server offers any resources to read. -type resourceCapabilities struct { +type ResourceCapabilities struct { // Whether this server supports notifications for changes to the resource list. ListChanged bool `json:"listChanged,omitempty"` // Whether this server supports subscribing to resource updates. @@ -993,23 +993,23 @@ type resourceCapabilities struct { // Capabilities that a server may support. Known capabilities are defined here, // in this schema, but this is not a closed set: any server can define its own, // additional capabilities. -type serverCapabilities struct { +type ServerCapabilities struct { // Present if the server supports argument autocompletion suggestions. - Completions *completionCapabilities `json:"completions,omitempty"` + Completions *CompletionCapabilities `json:"completions,omitempty"` // Experimental, non-standard capabilities that the server supports. Experimental map[string]struct{} `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. - Logging *loggingCapabilities `json:"logging,omitempty"` + Logging *LoggingCapabilities `json:"logging,omitempty"` // Present if the server offers any prompt templates. - Prompts *promptCapabilities `json:"prompts,omitempty"` + Prompts *PromptCapabilities `json:"prompts,omitempty"` // Present if the server offers any resources to read. - Resources *resourceCapabilities `json:"resources,omitempty"` + Resources *ResourceCapabilities `json:"resources,omitempty"` // Present if the server offers any tools to call. - Tools *toolCapabilities `json:"tools,omitempty"` + Tools *ToolCapabilities `json:"tools,omitempty"` } // Present if the server offers any tools to call. -type toolCapabilities struct { +type ToolCapabilities struct { // Whether this server supports notifications for changes to the tool list. ListChanged bool `json:"listChanged,omitempty"` } diff --git a/mcp/server.go b/mcp/server.go index ed4ec720..5bc626b3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -242,27 +242,27 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { func() bool { return s.resourceTemplates.remove(uriTemplates...) }) } -func (s *Server) capabilities() *serverCapabilities { +func (s *Server) capabilities() *ServerCapabilities { s.mu.Lock() defer s.mu.Unlock() - caps := &serverCapabilities{ - Logging: &loggingCapabilities{}, + caps := &ServerCapabilities{ + Logging: &LoggingCapabilities{}, } if s.opts.HasTools || s.tools.len() > 0 { - caps.Tools = &toolCapabilities{ListChanged: true} + caps.Tools = &ToolCapabilities{ListChanged: true} } if s.opts.HasPrompts || s.prompts.len() > 0 { - caps.Prompts = &promptCapabilities{ListChanged: true} + caps.Prompts = &PromptCapabilities{ListChanged: true} } if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 { - caps.Resources = &resourceCapabilities{ListChanged: true} + caps.Resources = &ResourceCapabilities{ListChanged: true} if s.opts.SubscribeHandler != nil { caps.Resources.Subscribe = true } } if s.opts.CompletionHandler != nil { - caps.Completions = &completionCapabilities{} + caps.Completions = &CompletionCapabilities{} } return caps } diff --git a/mcp/server_test.go b/mcp/server_test.go index 202ab5d9..5482d51f 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -236,13 +236,13 @@ func TestServerCapabilities(t *testing.T) { name string configureServer func(s *Server) serverOpts ServerOptions - wantCapabilities *serverCapabilities + wantCapabilities *ServerCapabilities }{ { name: "No capabilities", configureServer: func(s *Server) {}, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, }, }, { @@ -250,9 +250,9 @@ func TestServerCapabilities(t *testing.T) { configureServer: func(s *Server) { s.AddPrompt(&Prompt{Name: "p"}, nil) }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Prompts: &PromptCapabilities{ListChanged: true}, }, }, { @@ -260,9 +260,9 @@ func TestServerCapabilities(t *testing.T) { configureServer: func(s *Server) { s.AddResource(&Resource{URI: "file:///r"}, nil) }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Resources: &resourceCapabilities{ListChanged: true}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Resources: &ResourceCapabilities{ListChanged: true}, }, }, { @@ -270,9 +270,9 @@ func TestServerCapabilities(t *testing.T) { configureServer: func(s *Server) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Resources: &resourceCapabilities{ListChanged: true}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Resources: &ResourceCapabilities{ListChanged: true}, }, }, { @@ -288,9 +288,9 @@ func TestServerCapabilities(t *testing.T) { return nil }, }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Resources: &ResourceCapabilities{ListChanged: true, Subscribe: true}, }, }, { @@ -298,9 +298,9 @@ func TestServerCapabilities(t *testing.T) { configureServer: func(s *Server) { s.AddTool(tool, nil) }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Tools: &toolCapabilities{ListChanged: true}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, }, }, { @@ -311,9 +311,9 @@ func TestServerCapabilities(t *testing.T) { return nil, nil }, }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Completions: &completionCapabilities{}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Completions: &CompletionCapabilities{}, }, }, { @@ -335,12 +335,12 @@ func TestServerCapabilities(t *testing.T) { return nil, nil }, }, - wantCapabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, - Tools: &toolCapabilities{ListChanged: true}, + wantCapabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Prompts: &PromptCapabilities{ListChanged: true}, + Resources: &ResourceCapabilities{ListChanged: true, Subscribe: true}, + Tools: &ToolCapabilities{ListChanged: true}, }, }, { @@ -351,11 +351,11 @@ func TestServerCapabilities(t *testing.T) { HasResources: true, HasTools: true, }, - wantCapabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, - Tools: &toolCapabilities{ListChanged: true}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Prompts: &PromptCapabilities{ListChanged: true}, + Resources: &ResourceCapabilities{ListChanged: true}, + Tools: &ToolCapabilities{ListChanged: true}, }, }, } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 9c7f5f9c..4181303f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -409,9 +409,9 @@ func TestStreamableServerTransport(t *testing.T) { // Predefined steps, to avoid repetition below. initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ - Capabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Tools: &toolCapabilities{ListChanged: true}, + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion, ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, @@ -891,9 +891,10 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { } } initResult := &InitializeResult{ - Capabilities: &serverCapabilities{ - Logging: &loggingCapabilities{}, - Tools: &toolCapabilities{ListChanged: true}, + Capabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion, ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, From 52734fd4144bfd41cd5ef280f399b79d69bd99cf Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 19 Aug 2025 07:49:56 -0400 Subject: [PATCH 105/221] mcp: rename logging symbols (#321) Fixes #279. --- mcp/client.go | 5 ++--- mcp/mcp_test.go | 2 +- mcp/protocol.go | 8 ++++---- mcp/server.go | 2 +- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 65a7a954..dcabbfd6 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -103,8 +103,7 @@ func (e unsupportedProtocolVersionError) Error() string { } // ClientSessionOptions is reserved for future use. -type ClientSessionOptions struct { -} +type ClientSessionOptions struct{} // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. @@ -398,7 +397,7 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } -func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { +func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error { _, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params))) return err } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index d04235b8..4a10d304 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -404,7 +404,7 @@ func TestEndToEnd(t *testing.T) { // Nothing should be logged until the client sets a level. mustLog("info", "before") - if err := cs.SetLevel(ctx, &SetLevelParams{Level: "warning"}); err != nil { + if err := cs.SetLoggingLevel(ctx, &SetLoggingLevelParams{Level: "warning"}); err != nil { t.Fatal(err) } mustLog("warning", want[0].Data) diff --git a/mcp/protocol.go b/mcp/protocol.go index 7fbfccf0..0125cb13 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -826,7 +826,7 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error { return nil } -type SetLevelParams struct { +type SetLoggingLevelParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` @@ -836,9 +836,9 @@ type SetLevelParams struct { Level LoggingLevel `json:"level"` } -func (x *SetLevelParams) isParams() {} -func (x *SetLevelParams) GetProgressToken() any { return getProgressToken(x) } -func (x *SetLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *SetLoggingLevelParams) isParams() {} +func (x *SetLoggingLevelParams) GetProgressToken() any { return getProgressToken(x) } +func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } // Definition for a tool the client can call. type Tool struct { diff --git a/mcp/server.go b/mcp/server.go index 5bc626b3..1d251219 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -857,7 +857,7 @@ func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, erro return nil, nil } -func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { +func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelParams) (*emptyResult, error) { ss.updateState(func(state *ServerSessionState) { state.LogLevel = params.Level }) From 112ca4e3b10aef5fa01d1e7e753eecd5c106d603 Mon Sep 17 00:00:00 2001 From: cryo Date: Sun, 17 Aug 2025 09:21:45 +0000 Subject: [PATCH 106/221] mcp: Ensure keepalive goroutine is started exactly once --- mcp/server.go | 6 +++--- mcp/server_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 1d251219..d118464a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -602,9 +602,6 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar // params are non-nil. params = new(InitializedParams) } - if ss.server.opts.KeepAlive > 0 { - ss.startKeepalive(ss.server.opts.KeepAlive) - } var wasInit, wasInitd bool ss.updateState(func(state *ServerSessionState) { wasInit = state.InitializeParams != nil @@ -620,6 +617,9 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar if wasInitd { return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } if h := ss.server.opts.InitializedHandler; h != nil { h(ctx, serverRequestFor(ss, params)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index 5482d51f..adadc9c3 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -9,6 +9,7 @@ import ( "log" "slices" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" @@ -371,3 +372,51 @@ func TestServerCapabilities(t *testing.T) { }) } } + +// TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once, +// ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism. +func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { + // Set KeepAlive to a long duration to ensure the keepalive + // goroutine stays alive for the duration of the test without actually sending + // ping requests, since we don't have a real client connection established. + server := NewServer(testImpl, &ServerOptions{KeepAlive: 5 * time.Second}) + ss := &ServerSession{server: server} + + // 1. Initialize the session. + _, err := ss.initialize(context.Background(), &InitializeParams{}) + if err != nil { + t.Fatalf("ServerSession initialize failed: %v", err) + } + + // 2. Call 'initialized' for the first time. This should start the keepalive mechanism. + _, err = ss.initialized(context.Background(), &InitializedParams{}) + if err != nil { + t.Fatalf("First initialized call failed: %v", err) + } + if ss.keepaliveCancel == nil { + t.Fatalf("expected ServerSession.keepaliveCancel to be set after the first call of initialized") + } + + // Save the cancel function and use defer to ensure resources are cleaned up. + firstCancel := ss.keepaliveCancel + defer firstCancel() + + // 3. Manually set the field to nil. + // Do this to facilitate the test's core assertion. The goal is to verify that + // 'ss.keepaliveCancel' is not assigned a second time. By setting it to nil, + // we can easily check after the next call if a new keepalive goroutine was started. + ss.keepaliveCancel = nil + + // 4. Call 'initialized' for the second time. This should return an error. + _, err = ss.initialized(context.Background(), &InitializedParams{}) + if err == nil { + t.Fatalf("Expected 'duplicate initialized received' error on second call, got nil") + } + + // 5. Re-check the field to ensure it remains nil. + // Since 'initialized' correctly returned an error and did not call + // 'startKeepalive', the field should remain unchanged. + if ss.keepaliveCancel != nil { + t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized") + } +} From 3db848abec869b29015b8c5bee5bc335fa928b09 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 18 Aug 2025 18:15:06 +0000 Subject: [PATCH 107/221] mcp: implement a concurrency model for calls Implement the concurrency model described in #26: notifications are synchronous, but calls are asynchronous (except for 'initialize'). To achieve this, implement jsonrpc2.Async(ctx) to signal asynchronous handling. This is simpler to use than returning ErrAsyncResponse and calling Respond, and since this is an internal detail we don't need to worry too much about whether it's idiomatic. Add tests that verify both features, for both client and server. Also: - replace req.ID.IsValid with req.IsCall - remove the methodHandler type as we can just use MethodHandler Fixes #26 --- internal/jsonrpc2/conn.go | 82 ++++++++++-------- internal/jsonrpc2/jsonrpc2.go | 7 -- internal/jsonrpc2/jsonrpc2_test.go | 16 ++-- mcp/client.go | 7 +- mcp/conformance_test.go | 6 +- mcp/content.go | 2 +- mcp/mcp_test.go | 133 ++++++++++++++++++++++++++--- mcp/server.go | 13 ++- mcp/shared.go | 24 ++---- mcp/streamable.go | 2 +- mcp/streamable_test.go | 2 +- 11 files changed, 208 insertions(+), 86 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 6f48c9ba..963350e7 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -374,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async return ac } +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + type AsyncCall struct { id ID ready chan struct{} // closed after response has been set @@ -425,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error { return json.Unmarshal(ac.response.Result, result) } -// Respond delivers a response to an incoming Call. -// -// Respond must be called exactly once for any message for which a handler -// returns ErrAsyncResponse. It must not be called for any other message. -func (c *Connection) Respond(id ID, result any, err error) error { - var req *incomingRequest - c.updateInFlight(func(s *inFlightState) { - req = s.incomingByID[id] - }) - if req == nil { - return c.internalErrorf("Request not found for ID %v", id) - } - - if err == ErrAsyncResponse { - // Respond is supposed to supply the asynchronous response, so it would be - // confusing to call Respond with an error that promises to call Respond - // again. - err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method) - } - return c.processResult("Respond", req, result, err) -} - // Cancel cancels the Context passed to the Handle call for the inbound message // with the given ID. // @@ -576,11 +594,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter if preempter != nil { result, err := preempter.Preempt(req.ctx, req.Request) - if req.IsCall() && errors.Is(err, ErrAsyncResponse) { - // This request will remain in flight until Respond is called for it. - return - } - if !errors.Is(err, ErrNotHandled) { c.processResult("Preempt", req, result, err) return @@ -655,19 +668,20 @@ func (c *Connection) handleAsync() { continue } - result, err := c.handler.Handle(req.ctx, req.Request) - c.processResult(c.handler, req, result, err) + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch } } // processResult processes the result of a request and, if appropriate, sends a response. func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { switch err { - case ErrAsyncResponse: - if !req.IsCall() { - return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method) - } - return nil // This request is still in flight, so don't record the result yet. case ErrNotHandled, ErrMethodNotFound: // Add detail describing the unhandled method. err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index b9c320c8..234e6ee3 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -22,13 +22,6 @@ var ( // If a Handler returns ErrNotHandled, the server replies with // ErrMethodNotFound. ErrNotHandled = errors.New("JSON RPC not handled") - - // ErrAsyncResponse is returned from a handler to indicate it will generate a - // response asynchronously. - // - // ErrAsyncResponse must not be returned for notifications, - // which do not receive responses. - ErrAsyncResponse = errors.New("JSON RPC asynchronous response") ) // Preempter handles messages on a connection before they are queued to the main diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 16a5039b..8c79300c 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -371,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error if err := json.Unmarshal(req.Params, &name); err != nil { return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) } + jsonrpc2.Async(ctx) waitFor := h.waiter(name) - go func() { - select { - case <-waitFor: - h.conn.Respond(req.ID, true, nil) - case <-ctx.Done(): - h.conn.Respond(req.ID, nil, ctx.Err()) - } - }() - return nil, jsonrpc2.ErrAsyncResponse + select { + case <-waitFor: + return true, nil + case <-ctx.Done(): + return nil, ctx.Err() + } default: return nil, jsonrpc2.ErrNotHandled } diff --git a/mcp/client.go b/mcp/client.go index dcabbfd6..344e2f3d 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -328,16 +328,19 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { } func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + if req.IsCall() { + jsonrpc2.Async(ctx) + } return handleReceive(ctx, cs, req) } -func (cs *ClientSession) sendingMethodHandler() methodHandler { +func (cs *ClientSession) sendingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.sendingMethodHandler_ } -func (cs *ClientSession) receivingMethodHandler() methodHandler { +func (cs *ClientSession) receivingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.receivingMethodHandler_ diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 8e6ea1be..9bd8b8f6 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -183,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { return nil, err, false } serverMessages = append(serverMessages, msg) - if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() { // Pair up the next outgoing response with this request. // We assume requests arrive in the same order every time. if len(outResponses) == 0 { @@ -201,8 +201,8 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Synthetic peer interacts with real peer. for _, req := range outRequests { writeMsg(req) - if req.ID.IsValid() { - // A request (as opposed to a notification). Wait for the response. + if req.IsCall() { + // A call (as opposed to a notification). Wait for the response. res, err, ok := nextResponse() if err != nil { t.Fatalf("reading server messages failed: %v", err) diff --git a/mcp/content.go b/mcp/content.go index f8777154..108b0271 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -253,7 +253,7 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { if wire == nil { - return nil, fmt.Errorf("content wire is nil") + return nil, fmt.Errorf("nil content") } if allow != nil && !allow[wire.Type] { return nil, fmt.Errorf("invalid content type %q", wire.Type) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4a10d304..159f878f 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -17,6 +17,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "testing" "time" @@ -549,31 +550,47 @@ func errorCode(err error) int64 { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) { +func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) { + return basicClientServerConnection(t, nil, nil, config) +} + +// basicClientServerConnection creates a basic connection between client and +// server. If either client or server is nil, empty implementations are used. +// +// The provided function may be used to configure features on the resulting +// server, prior to connection. +// +// The caller should cancel either the client connection or server connection +// when the connections are no longer needed. +func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) { t.Helper() ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer(testImpl, nil) + if server == nil { + server = NewServer(testImpl, nil) + } if config != nil { - config(s) + config(server) } - ss, err := s.Connect(ctx, st, nil) + ss, err := server.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) + if client == nil { + client = NewClient(testImpl, nil) + } + cs, err := client.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } - return ss, cs + return cs, ss } func TestServerClosing(t *testing.T) { - cc, cs := basicConnection(t, func(s *Server) { + cs, ss := basicConnection(t, func(s *Server) { AddTool(s, greetTool(), sayHi) }) defer cs.Close() @@ -593,7 +610,7 @@ func TestServerClosing(t *testing.T) { }); err != nil { t.Fatalf("after connecting: %v", err) } - cc.Close() + ss.Close() wg.Wait() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", @@ -656,7 +673,7 @@ func TestCancellation(t *testing.T) { } return nil, nil } - _, cs := basicConnection(t, func(s *Server) { + cs, _ := basicConnection(t, func(s *Server) { AddTool(s, &Tool{Name: "slow"}, slowRequest) }) defer cs.Close() @@ -940,7 +957,7 @@ func TestKeepAliveFailure(t *testing.T) { func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Adding the same tool pointer twice should not panic and should not // produce duplicates in the server's tool list. - _, cs := basicConnection(t, func(s *Server) { + cs, _ := basicConnection(t, func(s *Server) { // Use two distinct Tool instances with the same name but different // descriptions to ensure the second replaces the first // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors @@ -972,4 +989,98 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { } } +func TestSynchronousNotifications(t *testing.T) { + var toolsChanged atomic.Bool + clientOpts := &ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) { + toolsChanged.Store(true) + }, + CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + if !toolsChanged.Load() { + return nil, fmt.Errorf("didn't get a tools changed notification") + } + // TODO(rfindley): investigate the error returned from this test if + // CreateMessageResult is new(CreateMessageResult): it's a mysterious + // unmarshalling error that we should improve. + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + + var rootsChanged atomic.Bool + serverOpts := &ServerOptions{ + RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) { + rootsChanged.Store(true) + }, + } + server := NewServer(testImpl, serverOpts) + cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + if !rootsChanged.Load() { + return nil, fmt.Errorf("didn't get root change notification") + } + return new(CallToolResult), nil + }) + }) + + t.Run("from client", func(t *testing.T) { + client.AddRoots(&Root{Name: "myroot", URI: "file://foo"}) + res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"}) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + if res.IsError { + t.Errorf("tool error: %v", res.Content[0].(*TextContent).Text) + } + }) + + t.Run("from server", func(t *testing.T) { + server.RemoveTools("tool") + if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil { + t.Errorf("CreateMessage failed: %v", err) + } + }) +} + +func TestNoDistributedDeadlock(t *testing.T) { + // This test verifies that calls are asynchronous, and so it's not possible + // to have a distributed deadlock. + // + // The setup creates potential deadlock for both the client and server: the + // client sends a call to tool1, which itself calls createMessage, which in + // turn calls tool2, which calls ping. + // + // If the server were not asynchronous, the call to tool2 would hang. If the + // client were not asynchronous, the call to ping would hang. + // + // Such a scenario is unlikely in practice, but is still theoretically + // possible, and in any case making tool calls asynchronous by default + // delegates synchronization to the user. + clientOpts := &ClientOptions{ + CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}) + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + req.Session.CreateMessage(ctx, new(CreateMessageParams)) + return new(CallToolResult), nil + }) + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + req.Session.Ping(ctx, nil) + return new(CallToolResult), nil + }) + }) + defer cs.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil { + // should not deadlock + t.Fatalf("CallTool failed: %v", err) + } +} + var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/server.go b/mcp/server.go index d118464a..4061980d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -779,14 +779,14 @@ func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return cli func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } -func (ss *ServerSession) sendingMethodHandler() methodHandler { +func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server s.mu.Lock() defer s.mu.Unlock() return s.sendingMethodHandler_ } -func (ss *ServerSession) receivingMethodHandler() methodHandler { +func (ss *ServerSession) receivingMethodHandler() MethodHandler { s := ss.server s.mu.Lock() defer s.mu.Unlock() @@ -801,6 +801,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, ss.mu.Lock() initialized := ss.state.InitializedParams != nil ss.mu.Unlock() + // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." @@ -811,6 +812,14 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } + + // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and + // notifications synchronously, except for 'initialize' which shouldn't be + // asynchronous to other + if req.IsCall() && req.Method != methodInitialize { + jsonrpc2.Async(ctx) + } + // For the streamable transport, we need the request ID to correlate // server->client calls and notifications to the incoming request from which // they originated. See [idContextKey] for details. diff --git a/mcp/shared.go b/mcp/shared.go index ca062214..608e2aaf 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -38,12 +38,6 @@ var supportedProtocolVersions = []string{ // For notifications, both must be nil. type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) -// A methodHandler is a MethodHandler[Session] for some session. -// We need to give up type safety here, or we will end up with a type cycle somewhere -// else. For example, if Session.methodHandler returned a MethodHandler[Session], -// the compiler would complain. -type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerSession] - // A Session is either a [ClientSession] or a [ServerSession]. type Session interface { // ID returns the session ID, or the empty string if there is none. @@ -51,8 +45,8 @@ type Session interface { sendingMethodInfos() map[string]methodInfo receivingMethodInfos() map[string]methodInfo - sendingMethodHandler() methodHandler - receivingMethodHandler() methodHandler + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler getConn() *jsonrpc2.Connection } @@ -95,13 +89,13 @@ func orZero[T any, P *U, U any](p P) T { } func handleNotify(ctx context.Context, method string, req Request) error { - mh := req.GetSession().sendingMethodHandler().(MethodHandler) + mh := req.GetSession().sendingMethodHandler() _, err := mh(ctx, method, req) return err } func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { - mh := req.GetSession().sendingMethodHandler().(MethodHandler) + mh := req.GetSession().sendingMethodHandler() // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, method, req) if err != nil { @@ -118,7 +112,7 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, method string // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } - return info.handleMethod.(MethodHandler)(ctx, method, req) + return info.handleMethod(ctx, method, req) } func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { @@ -131,7 +125,7 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) } - mh := session.receivingMethodHandler().(MethodHandler) + mh := session.receivingMethodHandler() req := info.newRequest(session, params) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) @@ -154,10 +148,10 @@ func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo if !ok { return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) } - if info.flags¬ification != 0 && req.ID.IsValid() { + if info.flags¬ification != 0 && req.IsCall() { return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } - if info.flags¬ification == 0 && !req.ID.IsValid() { + if info.flags¬ification == 0 && !req.IsCall() { return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } // missingParamsOK is checked here to catch the common case where "params" is @@ -182,7 +176,7 @@ type methodInfo struct { newRequest func(Session, Params) Request // Run the code when a call to the method is received. // Used on the receive side. - handleMethod methodHandler + handleMethod MethodHandler // Create a pointer to a Result struct. // Used on the send side. newResult func() Result diff --git a/mcp/streamable.go b/mcp/streamable.go index 9ae20c02..1ecf201f 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -509,7 +509,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } - if req.ID.IsValid() { + if req.IsCall() { requests[req.ID] = struct{}{} } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 4181303f..e0b00cc6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -689,7 +689,7 @@ func TestStreamableServerTransport(t *testing.T) { defer wg.Done() for m := range out { - if req, ok := m.(*jsonrpc.Request); ok && req.ID.IsValid() { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { // Encountered a server->client request. We should have a // response queued. Otherwise, we may deadlock. mu.Lock() From 46ba813860e62405b3534f00e9eaddb1b63d7069 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Tue, 19 Aug 2025 14:58:01 -0400 Subject: [PATCH 108/221] .github: add a staticcheck action (#326) Add a staticcheck step to the lint job of the Test workflow, and fix reported diagnostics. --- .github/workflows/test.yml | 4 ++++ internal/jsonrpc2/conn.go | 8 ++++---- mcp/client.go | 4 ---- mcp/streamable.go | 5 +---- mcp/streamable_test.go | 4 ++-- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5eb2dacd..d8e2d31e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,6 +30,10 @@ jobs: echo "All Go files are properly formatted" - name: Run Go vet run: go vet ./... + - name: Run staticcheck + uses: dominikh/staticcheck-action@v1 + with: + version: "latest" test: runs-on: ubuntu-latest diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 963350e7..6bacfa7e 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -719,10 +719,10 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e } else if err != nil { err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err) } - if err != nil { - // TODO: can/should we do anything with this error beyond writing it to the event log? - // (Is this the right label to attach to the log?) - } + } + if err != nil { + // TODO: can/should we do anything with this error beyond writing it to the event log? + // (Is this the right label to attach to the log?) } // Cancel the request to free any associated resources. diff --git a/mcp/client.go b/mcp/client.go index 344e2f3d..d1d17502 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -178,10 +178,6 @@ type clientSessionState struct { func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult } -func (cs *ClientSession) setConn(c Connection) { - cs.mcpConn = c -} - func (cs *ClientSession) ID() string { if c, ok := cs.mcpConn.(hasSessionID); ok { return c.SessionID() diff --git a/mcp/streamable.go b/mcp/streamable.go index 1ecf201f..526ee515 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -121,7 +121,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var transport *StreamableServerTransport if id := req.Header.Get(sessionIDHeader); id != "" { h.mu.Lock() - transport, _ = h.transports[id] + transport = h.transports[id] h.mu.Unlock() if transport == nil { http.Error(w, "session not found", http.StatusNotFound) @@ -288,9 +288,6 @@ type streamableServerConn struct { jsonResponse bool eventStore EventStore - lastStreamID atomic.Int64 // last stream ID used, atomically incremented - - opts StreamableServerTransportOptions incoming chan jsonrpc.Message // messages from the client to the server done chan struct{} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e0b00cc6..52c47998 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -539,10 +539,10 @@ func TestStreamableServerTransport(t *testing.T) { }, { name: "background", - tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { + tool: func(t *testing.T, _ context.Context, ss *ServerSession) { // Perform operations on a background context, and ensure the client // receives it. - ctx = context.Background() + ctx := context.Background() if err := ss.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { t.Errorf("Notify failed: %v", err) } From 732b97fdaf705732af4c63247e0c28362bd927ed Mon Sep 17 00:00:00 2001 From: cryo Date: Wed, 20 Aug 2025 05:43:26 +0800 Subject: [PATCH 109/221] mcp: add syntax and scheme validation to AddResourceTemplate (#253) Add template validation and scheme check in `AddResourceTemplate`. --- mcp/server.go | 15 ++++++++++++++- mcp/server_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/mcp/server.go b/mcp/server.go index 4061980d..c3fbd9e3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -22,6 +22,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/yosida95/uritemplate/v3" ) const DefaultPageSize = 1000 @@ -229,7 +230,19 @@ func (s *Server) RemoveResources(uris ...string) { func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - // TODO: check template validity. + // Validate the URI template syntax + _, err := uritemplate.New(t.URITemplate) + if err != nil { + panic(fmt.Errorf("URI template %q is invalid: %w", t.URITemplate, err)) + } + // Ensure the URI template has a valid scheme + u, err := url.Parse(t.URITemplate) + if err != nil { + panic(err) // url.Parse includes the URI in the error + } + if !u.IsAbs() { + panic(fmt.Errorf("URI template %q needs a scheme", t.URITemplate)) + } s.resourceTemplates.add(&serverResourceTemplate{t, h}) return true }) diff --git a/mcp/server_test.go b/mcp/server_test.go index adadc9c3..39a4cdb4 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -373,6 +373,42 @@ func TestServerCapabilities(t *testing.T) { } } +func TestServerAddResourceTemplate(t *testing.T) { + tests := []struct { + name string + template string + expectPanic bool + }{ + {"ValidFileTemplate", "file:///{a}/{b}", false}, + {"ValidCustomScheme", "myproto:///{a}", false}, + {"MissingScheme1", "://example.com/{path}", true}, + {"MissingScheme2", "/api/v1/users/{id}", true}, + {"EmptyVariable", "file:///{}/{b}", true}, + {"UnclosedVariable", "file:///{a", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rt := ResourceTemplate{URITemplate: tt.template} + + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("%s: unexpected panic: %v", tt.name, r) + } + } else { + if tt.expectPanic { + t.Errorf("%s: expected panic but did not panic", tt.name) + } + } + }() + + s := NewServer(testImpl, nil) + s.AddResourceTemplate(&rt, nil) + }) + } +} + // TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once, // ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism. func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { From 0eb03e0d760a48a5bcb3bb1071c733fb29fae2bf Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 14 Aug 2025 23:49:25 +0000 Subject: [PATCH 110/221] mcp: factor out reusable functions from TestStreamableServerTransport The next CL will test stateless and distributable server transport configurations, using the HTTP testing strategy of TestStreamableServerTransport. --- mcp/server.go | 2 + mcp/streamable_test.go | 367 ++++++++++++++++++++--------------------- 2 files changed, 182 insertions(+), 187 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index c3fbd9e3..88021336 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -355,6 +355,8 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name) } + // TODO: if handler returns nil content, it will serialize as null. + // Add a test and fix. return st.handler(ctx, req) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 52c47998..7c77938e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -382,30 +382,31 @@ func readNotifications(t *testing.T, ctx context.Context, notifications chan str } } +// JSON-RPC message constructors. +func req(id int64, method string, params any) *jsonrpc.Request { + r := &jsonrpc.Request{ + Method: method, + Params: mustMarshal(params), + } + if id > 0 { + r.ID = jsonrpc2.Int64ID(id) + } + return r +} + +func resp(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ + ID: jsonrpc2.Int64ID(id), + Result: mustMarshal(result), + Error: err, + } +} + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP // requests. - // JSON-RPC message constructors. - req := func(id int64, method string, params any) *jsonrpc.Request { - r := &jsonrpc.Request{ - Method: method, - Params: mustMarshal(t, params), - } - if id > 0 { - r.ID = jsonrpc2.Int64ID(id) - } - return r - } - resp := func(id int64, result any, err error) *jsonrpc.Response { - return &jsonrpc.Response{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), - Error: err, - } - } - // Predefined steps, to avoid repetition below. initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ @@ -422,21 +423,23 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{initReq}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, } initialized := streamableRequest{ method: "POST", messages: []jsonrpc.Message{initializedMsg}, wantStatusCode: http.StatusAccepted, + wantSessionID: false, // TODO: should this be true? } tests := []struct { - name string - tool func(*testing.T, context.Context, *ServerSession) - steps []streamableRequest + name string + tool func(*testing.T, context.Context, *ServerSession) + requests []streamableRequest // http requests }{ { name: "basic", - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -444,12 +447,13 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + wantSessionID: true, }, }, }, { name: "accept headers", - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, // Test various accept headers. @@ -458,12 +462,14 @@ func TestStreamableServerTransport(t *testing.T) { accept: []string{"text/plain", "application/*"}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream + wantSessionID: false, }, { method: "POST", accept: []string{"text/event-stream"}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json + wantSessionID: false, }, { method: "POST", @@ -471,6 +477,7 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantSessionID: true, }, { method: "POST", @@ -478,6 +485,7 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantSessionID: true, }, }, }, @@ -489,7 +497,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -502,6 +510,7 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, @@ -513,7 +522,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Call failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -523,6 +532,7 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, + wantSessionID: false, }, { method: "POST", @@ -534,6 +544,7 @@ func TestStreamableServerTransport(t *testing.T) { req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, @@ -554,7 +565,7 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []streamableRequest{ + requests: []streamableRequest{ initialize, initialized, { @@ -564,6 +575,7 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, + wantSessionID: false, }, { method: "GET", @@ -574,6 +586,7 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, + wantSessionID: true, }, { method: "POST", @@ -584,12 +597,13 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, + wantSessionID: true, }, }, }, { name: "errors", - steps: []streamableRequest{ + requests: []streamableRequest{ { method: "PUT", wantStatusCode: http.StatusMethodNotAllowed, @@ -615,6 +629,7 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, + wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session }, }, }, @@ -636,118 +651,127 @@ func TestStreamableServerTransport(t *testing.T) { handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() - httpServer := httptest.NewServer(handler) - defer httpServer.Close() + testStreamableHandler(t, handler, test.requests) + }) + } +} - // blocks records request blocks by jsonrpc. ID. - // - // When an OnRequest step is encountered, it waits on the corresponding - // block. When a request with that ID is received, the block is closed. - var mu sync.Mutex - blocks := make(map[int64]chan struct{}) - for _, step := range test.steps { - if step.onRequest > 0 { - blocks[step.onRequest] = make(chan struct{}) - } - } +func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) { + httpServer := httptest.NewServer(handler) + defer httpServer.Close() - // signal when all synchronous requests have executed, so we can fail - // async requests that are blocked. - syncRequestsDone := make(chan struct{}) + // blocks records request blocks by jsonrpc. ID. + // + // When an OnRequest step is encountered, it waits on the corresponding + // block. When a request with that ID is received, the block is closed. + var mu sync.Mutex + blocks := make(map[int64]chan struct{}) + for _, req := range requests { + if req.onRequest > 0 { + blocks[req.onRequest] = make(chan struct{}) + } + } - // To avoid complicated accounting for session ID, just set the first - // non-empty session ID from a response. - var sessionID atomic.Value - sessionID.Store("") + // signal when all synchronous requests have executed, so we can fail + // async requests that are blocked. + syncRequestsDone := make(chan struct{}) + + // To avoid complicated accounting for session ID, just set the first + // non-empty session ID from a response. + var sessionID atomic.Value + sessionID.Store("") + + // doStep executes a single step. + doStep := func(t *testing.T, i int, request streamableRequest) { + if request.onRequest > 0 { + // Block the step until we've received the server->client request. + mu.Lock() + block := blocks[request.onRequest] + mu.Unlock() + select { + case <-block: + case <-syncRequestsDone: + t.Errorf("after all sync requests are complete, request still blocked on %d", request.onRequest) + return + } + } - // doStep executes a single step. - doStep := func(t *testing.T, step streamableRequest) { - if step.onRequest > 0 { - // Block the step until we've received the server->client request. + // Collect messages received during this request, unblock other steps + // when requests are received. + var got []jsonrpc.Message + out := make(chan jsonrpc.Message) + // Cancel the step if we encounter a request that isn't going to be + // handled. + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + for m := range out { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { + // Encountered a server->client request. We should have a + // response queued. Otherwise, we may deadlock. mu.Lock() - block := blocks[step.onRequest] - mu.Unlock() - select { - case <-block: - case <-syncRequestsDone: - t.Errorf("after all sync requests are complete, request still blocked on %d", step.onRequest) - return + if block, ok := blocks[req.ID.Raw().(int64)]; ok { + close(block) + } else { + t.Errorf("no queued response for %v", req.ID) + cancel() } + mu.Unlock() } + got = append(got, m) + if request.closeAfter > 0 && len(got) == request.closeAfter { + cancel() + } + } + }() - // Collect messages received during this request, unblock other steps - // when requests are received. - var got []jsonrpc.Message - out := make(chan jsonrpc.Message) - // Cancel the step if we encounter a request that isn't going to be - // handled. - ctx, cancel := context.WithCancel(context.Background()) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - for m := range out { - if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { - // Encountered a server->client request. We should have a - // response queued. Otherwise, we may deadlock. - mu.Lock() - if block, ok := blocks[req.ID.Raw().(int64)]; ok { - close(block) - } else { - t.Errorf("no queued response for %v", req.ID) - cancel() - } - mu.Unlock() - } - got = append(got, m) - if step.closeAfter > 0 && len(got) == step.closeAfter { - cancel() - } - } - }() - - gotSessionID, gotStatusCode, err := step.do(ctx, httpServer.URL, sessionID.Load().(string), out) + gotSessionID, gotStatusCode, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) - // Don't fail on cancelled requests: error (if any) is handled - // elsewhere. - if err != nil && ctx.Err() == nil { - t.Fatal(err) - } + // Don't fail on cancelled requests: error (if any) is handled + // elsewhere. + if err != nil && ctx.Err() == nil { + t.Fatal(err) + } - if gotStatusCode != step.wantStatusCode { - t.Errorf("got status %d, want %d", gotStatusCode, step.wantStatusCode) - } - wg.Wait() + if gotStatusCode != request.wantStatusCode { + t.Errorf("request #%d: got status %d, want %d", i, gotStatusCode, request.wantStatusCode) + } + if got := gotSessionID != ""; got != request.wantSessionID { + t.Errorf("request #%d: got session id: %t, want %t", i, got, request.wantSessionID) + } + wg.Wait() - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) - if diff := cmp.Diff(step.wantMessages, got, transform); diff != "" { - t.Errorf("received unexpected messages (-want +got):\n%s", diff) - } - sessionID.CompareAndSwap("", gotSessionID) + if !request.ignoreResponse { + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) + if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { + t.Errorf("received unexpected messages (-want +got):\n%s", diff) } + } + sessionID.CompareAndSwap("", gotSessionID) + } - var wg sync.WaitGroup - for _, step := range test.steps { - if step.async || step.onRequest > 0 { - wg.Add(1) - go func() { - defer wg.Done() - doStep(t, step) - }() - } else { - doStep(t, step) - } - } + var wg sync.WaitGroup + for i, request := range requests { + if request.async || request.onRequest > 0 { + wg.Add(1) + go func() { + defer wg.Done() + doStep(t, i, request) + }() + } else { + doStep(t, i, request) + } + } - // Fail any blocked responses if they weren't needed by a synchronous - // request. - close(syncRequestsDone) + // Fail any blocked responses if they weren't needed by a synchronous + // request. + close(syncRequestsDone) - wg.Wait() - }) - } + wg.Wait() } // A streamableRequest describes a single streamable HTTP request, consisting @@ -768,13 +792,15 @@ type streamableRequest struct { async bool // Request attributes - method string // HTTP request method + method string // HTTP request method (required) accept []string // if non-empty, the Accept header to use; otherwise the default header is used messages []jsonrpc.Message // messages to send closeAfter int // if nonzero, close after receiving this many messages wantStatusCode int // expected status code + ignoreResponse bool // if set, don't check the response messages wantMessages []jsonrpc.Message // expected messages to receive + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -840,7 +866,8 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, newSessionID := resp.Header.Get("Mcp-Session-Id") - if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") { + contentType := resp.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "text/event-stream") { for evt, err := range scanEvents(resp.Body) { if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("reading events: %v", err) @@ -853,7 +880,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } out <- msg } - } else if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + } else if strings.HasPrefix(contentType, "application/json") { data, err := io.ReadAll(resp.Body) if err != nil { return newSessionID, resp.StatusCode, fmt.Errorf("reading json body: %w", err) @@ -868,14 +895,13 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, return newSessionID, resp.StatusCode, nil } -func mustMarshal(t *testing.T, v any) json.RawMessage { +func mustMarshal(v any) json.RawMessage { if v == nil { return nil } - t.Helper() data, err := json.Marshal(v) if err != nil { - t.Fatal(err) + panic(err) } return data } @@ -886,7 +912,7 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { resp := func(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), + Result: mustMarshal(result), Error: err, } } @@ -970,13 +996,10 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { - // Test stateless mode behavior - ctx := context.Background() - // This version of sayHi doesn't make a ping request (we can't respond to // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) @@ -985,57 +1008,27 @@ func TestStreamableStateless(t *testing.T) { handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ GetSessionID: func() string { return "" }, }) - httpServer := httptest.NewServer(handler) - defer httpServer.Close() - checkRequest := func(body string) { - // Verify we can call tools/list directly without initialization in stateless mode - req, err := http.NewRequestWithContext(ctx, http.MethodPost, httpServer.URL, strings.NewReader(body)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // Verify that no session ID header is returned in stateless mode - if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { - t.Errorf("%s = %s, want no session ID header", sessionIDHeader, sessionID) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("Status code = %d; want successful response", resp.StatusCode) - } - - var events []Event - for event, err := range scanEvents(resp.Body) { - if err != nil { - t.Fatal(err) - } - events = append(events, event) - } - if len(events) != 1 { - t.Fatalf("got %d SSE events, want 1; events: %v", len(events), events) - } - msg, err := jsonrpc.DecodeMessage(events[0].Data) - if err != nil { - t.Fatal(err) - } - jsonResp, ok := msg.(*jsonrpc.Response) - if !ok { - t.Errorf("event is %T, want response", jsonResp) - } - if jsonResp.Error != nil { - t.Errorf("request failed: %v", jsonResp.Error) - } + requests := []streamableRequest{ + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, + ignoreResponse: true, + wantSessionID: false, + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "World"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi World"}}}, nil), + }, + wantSessionID: false, + }, } - checkRequest(`{"jsonrpc":"2.0","method":"tools/list","id":1,"params":{}}`) - - // Verify we can make another request without session ID - checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`) + testStreamableHandler(t, handler, requests) } From 98e67fdf629f3fa5b7debc812815149ea85dca9a Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 15 Aug 2025 03:13:53 +0000 Subject: [PATCH 111/221] mcp: improvements for 'stateless' streamable servers; 'distributed' mode Several improvements for the stateless streamable mode, plus support for a 'distributed' (or rather, distributable) version of the stateless server. - Add a 'Stateless' option to StreamableHTTPOptions and StreamableServerTransport, which controls stateless behavior. GetSessionID may still return a non-empty session ID. - Audit validation of stateless mode to allow requests with a session id. Propagate this session ID to the temporary connection. - Peek at requests to allow 'initialize' requests to go through to the session, so that version negotiation can occur (FIXME: add tests). Fixes #284 For #148 --- internal/jsonrpc2/conn.go | 18 +++++- internal/jsonrpc2/wire.go | 11 ++++ mcp/streamable.go | 125 ++++++++++++++++++++++++++++++-------- mcp/streamable_test.go | 82 ++++++++++++++++++++++--- mcp/transport.go | 4 +- 5 files changed, 203 insertions(+), 37 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 6bacfa7e..537be47a 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -739,7 +739,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e // write is used by all things that write outgoing messages, including replies. // it makes sure that writes are atomic func (c *Connection) write(ctx context.Context, msg Message) error { - err := c.writer.Write(ctx, msg) + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifations through? It could + // be the case that writes can still succeed. + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For rejected requests, we don't set the writeErr (which would break the + // connection). They can just be returned to the caller. + if errors.Is(err, ErrRejected) { + return err + } if err != nil && ctx.Err() == nil { // The call to Write failed, and since ctx.Err() is nil we can't attribute diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index b143dcd3..8be2872e 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -37,6 +37,17 @@ var ( ErrServerClosing = NewError(-32004, "server is closing") // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. + // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32004, "rejected by transport") ) const wireVersion = "2.0" diff --git a/mcp/streamable.go b/mcp/streamable.go index 526ee515..789c7a0b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -38,18 +38,29 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions - mu sync.Mutex + mu sync.Mutex + // TODO: we should store the ServerSession along with the transport, because + // we need to cancel keepalive requests when closing the transport. transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } // StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. // - // If GetSessionID returns an empty string, the session is 'stateless', - // meaning it is not persisted and no session validation is performed. + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. GetSessionID func() string + // Stateless controls whether the session is 'stateless'. + // + // A stateless server does not validate the Mcp-Session-Id header, and uses a + // temporary session with default initialization parameters. Any + // server->client request is rejected immediately as there's no way for the + // client to respond. + Stateless bool + // TODO: support session retention (?) // jsonResponse is forwarded to StreamableServerTransport.jsonResponse. @@ -118,36 +129,39 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } + sessionID := req.Header.Get(sessionIDHeader) var transport *StreamableServerTransport - if id := req.Header.Get(sessionIDHeader); id != "" { + if sessionID != "" { h.mu.Lock() - transport = h.transports[id] + transport, _ = h.transports[sessionID] h.mu.Unlock() - if transport == nil { + if transport == nil && !h.opts.Stateless { + // In stateless mode we allow a missing transport. + // + // A synthetic transport will be created below for the transient session. http.Error(w, "session not found", http.StatusNotFound) return } } - // TODO(rfindley): simplify the locking so that each request has only one - // critical section. if req.Method == http.MethodDelete { - if transport == nil { - // => Mcp-Session-Id was not set; else we'd have returned NotFound above. + if sessionID == "" { http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.mu.Lock() - delete(h.transports, transport.SessionID) - h.mu.Unlock() - transport.connection.Close() + if transport != nil { // transport may be nil in stateless mode + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + transport.connection.Close() + } w.WriteHeader(http.StatusNoContent) return } switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && transport == nil { + if req.Method == http.MethodGet && sessionID == "" { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } @@ -164,37 +178,83 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "no server available", http.StatusBadRequest) return } - sessionID := h.opts.GetSessionID() - s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse} + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = h.opts.GetSessionID() + } + transport = &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + jsonResponse: h.opts.jsonResponse, + } // To support stateless mode, we initialize the session with a default // state, so that it doesn't reject subsequent requests. var connectOpts *ServerSessionOptions - if sessionID == "" { + if h.opts.Stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. + var hasInitialize, hasInitialized bool + { + // TODO: verify that this allows protocol version negotiation for + // stateless servers. + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + req.Body.Close() + + // Reset the body so that it can be read later. + req.Body = io.NopCloser(bytes.NewBuffer(body)) + + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + switch req.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } + } + } + } + } + + // If we don't have InitializeParams or InitializedParams in the request, + // set the initial state to a default value. + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = new(InitializeParams) + } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } connectOpts = &ServerSessionOptions{ - State: &ServerSessionState{ - InitializeParams: new(InitializeParams), - InitializedParams: new(InitializedParams), - }, + State: state, } } + // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - ss, err := server.Connect(req.Context(), s, connectOpts) + ss, err := server.Connect(req.Context(), transport, connectOpts) if err != nil { http.Error(w, "failed connection", http.StatusInternalServerError) return } - if sessionID == "" { + if h.opts.Stateless { // Stateless mode: close the session when the request exits. defer ss.Close() // close the fake session after handling the request } else { + // Otherwise, save the transport so that it can be reused h.mu.Lock() - h.transports[s.SessionID] = s + h.transports[transport.SessionID] = transport h.mu.Unlock() } - transport = s } transport.ServeHTTP(w, req) @@ -225,6 +285,13 @@ type StreamableServerTransport struct { // generator to produce one, as with [crypto/rand.Text].) SessionID string + // Stateless controls whether the eventstore is 'Stateless'. Servers sessions + // connected to a stateless transport are disallowed from making outgoing + // requests. + // + // See also [StreamableHTTPOptions.Stateless]. + Stateless bool + // Storage for events, to enable stream resumption. // If nil, a [MemoryEventStore] with the default maximum size will be used. EventStore EventStore @@ -265,6 +332,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) } t.connection = &streamableServerConn{ sessionID: t.SessionID, + stateless: t.Stateless, eventStore: t.EventStore, jsonResponse: t.jsonResponse, incoming: make(chan jsonrpc.Message, 10), @@ -285,6 +353,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) type streamableServerConn struct { sessionID string + stateless bool jsonResponse bool eventStore EventStore @@ -755,6 +824,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error // Write implements the [Connection] interface. func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { + // Requests aren't possible with stateless servers, or when there's no session ID. + return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) + } // Find the incoming request that this write relates to, if any. var forRequest jsonrpc.ID isResponse := false diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 7c77938e..8334bc0d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -748,7 +748,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream if !request.ignoreResponse { transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { - t.Errorf("received unexpected messages (-want +got):\n%s", diff) + t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff) } } sessionID.CompareAndSwap("", gotSessionID) @@ -996,19 +996,18 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { - // This version of sayHi doesn't make a ping request (we can't respond to + // This version of sayHi expects // that request from our client). sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) { + if err := req.Session.Ping(ctx, nil); err == nil { + // ping should fail, but not break the connection + t.Errorf("ping succeeded unexpectedly") + } return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) - // Test stateless mode. - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ - GetSessionID: func() string { return "" }, - }) - requests := []streamableRequest{ { method: "POST", @@ -1028,7 +1027,74 @@ func TestStreamableStateless(t *testing.T) { }, wantSessionID: false, }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil), + }, + wantSessionID: false, + }, + } + + testClientCompatibility := func(t *testing.T, handler http.Handler) { + ctx := context.Background() + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatal(err) + } + res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}}) + if err != nil { + t.Fatal(err) + } + if got, want := textContent(t, res), "hi bar"; got != want { + t.Errorf("Result = %q, want %q", got, want) + } } - testStreamableHandler(t, handler, requests) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { return "" }, + Stateless: true, + }) + + // Test the default stateless mode. + t.Run("stateless", func(t *testing.T) { + testStreamableHandler(t, handler, requests) + testClientCompatibility(t, handler) + }) + + // Test a "distributed" variant of stateless mode, where it has non-empty + // session IDs, but is otherwise stateless. + // + // This can be used by tools to look up application state preserved across + // subsequent requests. + for i, req := range requests { + // Now, we want a session for all requests. + req.wantSessionID = true + requests[i] = req + } + distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) + t.Run("distributed", func(t *testing.T) { + testStreamableHandler(t, distributableHandler, requests) + testClientCompatibility(t, handler) + }) +} + +func textContent(t *testing.T, res *CallToolResult) string { + t.Helper() + if len(res.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(res.Content)) + } + text, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatalf("Content[0] is %T, want *TextContent", res.Content[0]) + } + return text.Text } diff --git a/mcp/transport.go b/mcp/transport.go index 6d25de33..8018910b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -40,8 +40,8 @@ type Transport interface { type Connection interface { // Read reads the next message to process off the connection. // - // Read need not be safe for concurrent use: Read is called in a - // concurrency-safe manner by the JSON-RPC library. + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. Read(context.Context) (jsonrpc.Message, error) // Write writes a new message to the connection. From 1a54234c27373f3bf46093da2b33a32198cc7427 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 18 Aug 2025 20:14:00 +0000 Subject: [PATCH 112/221] mcp: fix reconnect semantics for hanging GET A few problems with reconnection cropped up in the review of PR #307. We should allow for the hanging GET to fail with StatusMethodNotAllowed. This simply means that the server does not support sending notifications or requests over the GET, which is allowed in the spec. Also, we should fix the initial delay of the hanging GET request: it should start with 0 delay. Fix the math for this and subsequent attempts. Incidentally, this makes the tests take 3s on my machine, down from 9s. Also address some comments from #307. --- internal/jsonrpc2/conn.go | 4 +-- mcp/streamable.go | 68 +++++++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 537be47a..49902b00 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -742,8 +742,8 @@ func (c *Connection) write(ctx context.Context, msg Message) error { var err error // Fail writes immediately if the connection is shutting down. // - // TODO(rfindley): should we allow cancellation notifations through? It could - // be the case that writes can still succeed. + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. c.updateInFlight(func(s *inFlightState) { err = s.shuttingDown(ErrServerClosing) }) diff --git a/mcp/streamable.go b/mcp/streamable.go index 789c7a0b..572fe5de 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -49,6 +49,9 @@ type StreamableHTTPOptions struct { // GetSessionID provides the next session ID to use for an incoming request. // If nil, a default randomly generated ID will be used. // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // // As a special case, if GetSessionID returns the empty string, the // Mcp-Session-Id header will not be set. GetSessionID func() string @@ -58,7 +61,9 @@ type StreamableHTTPOptions struct { // A stateless server does not validate the Mcp-Session-Id header, and uses a // temporary session with default initialization parameters. Any // server->client request is rejected immediately as there's no way for the - // client to respond. + // client to respond. Server->Client notifications may reach the client if + // they are made in the context of an incoming request, as described in the + // documentation for [StreamableServerTransport]. Stateless bool // TODO: support session retention (?) @@ -133,12 +138,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var transport *StreamableServerTransport if sessionID != "" { h.mu.Lock() - transport, _ = h.transports[sessionID] + transport = h.transports[sessionID] h.mu.Unlock() if transport == nil && !h.opts.Stateless { - // In stateless mode we allow a missing transport. + // Unless we're in 'stateless' mode, which doesn't perform any Session-ID + // validation, we require that the session ID matches a known session. // - // A synthetic transport will be created below for the transient session. + // In stateless mode, a temporary transport is be created below. http.Error(w, "session not found", http.StatusNotFound) return } @@ -201,7 +207,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // stateless servers. body, err := io.ReadAll(req.Body) if err != nil { - http.Error(w, "failed to read body", http.StatusBadRequest) + http.Error(w, "failed to read body", http.StatusInternalServerError) return } req.Body.Close() @@ -272,9 +278,22 @@ type StreamableServerTransportOptions struct { // A StreamableServerTransport implements the server side of the MCP streamable // transport. // -// Each StreamableServerTransport may be connected (via [Server.Connect]) at +// Each StreamableServerTransport must be connected (via [Server.Connect]) at // most once, since [StreamableServerTransport.ServeHTTP] serves messages to // the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the hanging POST response, or to the hanging GET, according +// to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. +// - Requests or notifications made with a context.Context value derived from +// an incoming request handler, are routed to the HTTP response +// corresponding to that request, unless it has already terminated, in +// which case they are routed to the hanging GET. +// - Requests or notifications made with a detached context.Context value are +// routed to the hanging GET. type StreamableServerTransport struct { // SessionID is the ID of this session. // @@ -285,7 +304,7 @@ type StreamableServerTransport struct { // generator to produce one, as with [crypto/rand.Text].) SessionID string - // Stateless controls whether the eventstore is 'Stateless'. Servers sessions + // Stateless controls whether the eventstore is 'Stateless'. Server sessions // connected to a stateless transport are disallowed from making outgoing // requests. // @@ -1225,9 +1244,18 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent c.fail(err) return } - - // Reconnection was successful. Continue the loop with the new response. resp = newResp + if resp.StatusCode == http.StatusMethodNotAllowed && persistent { + // The server doesn't support the hanging GET. + resp.Body.Close() + return + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + resp.Body.Close() + c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode))) + return + } + // Reconnection was successful. Continue the loop with the new response. } } @@ -1295,13 +1323,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er finalErr = err // Store the error and try again. continue } - - if !isResumable(resp) { - // The server indicated we should not continue. - resp.Body.Close() - return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status) - } - return resp, nil } } @@ -1312,16 +1333,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) } -// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. -func isResumable(resp *http.Response) bool { - // Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. - if resp.StatusCode == http.StatusMethodNotAllowed { - return false - } - - return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") -} - // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { @@ -1361,8 +1372,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } // Calculate the exponential backoff using the grow factor. - backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt))) + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) // Cap the backoffDuration at maxDelay. backoffDuration = min(backoffDuration, reconnectMaxDelay) From 8e6ab130d8555fe447503e611b1f5906f35b2dd1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 20 Aug 2025 10:50:16 -0400 Subject: [PATCH 113/221] mcp: pass TokenInfo to server handler (#292) If there is a TokenInfo in the request context of a StreamableServerTransport, then propagate it through to the ServerRequest that is passed to server methods like callTool. Fixes #317. --- auth/auth.go | 27 ++++++++++++++++--- internal/jsonrpc2/messages.go | 6 +++++ mcp/shared.go | 25 +++++++++++------- mcp/streamable.go | 14 ++++++++++ mcp/streamable_test.go | 49 +++++++++++++++++++++++++++++++++++ mcp/tool.go | 1 + mcp/transport.go | 3 +-- 7 files changed, 109 insertions(+), 16 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 68873b48..14ad28c7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -13,22 +13,41 @@ import ( "time" ) +// TokenInfo holds information from a bearer token. type TokenInfo struct { Scopes []string Expiration time.Time + // TODO: add standard JWT fields + Extra map[string]any } +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) +// RequireBearerTokenOptions are options for [RequireBearerToken]. type RequireBearerTokenOptions struct { - Scopes []string + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. ResourceMetadataURL string + // The required scopes. + Scopes []string } -var ErrInvalidToken = errors.New("invalid token") - type tokenInfoKey struct{} +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + // RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. // If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. // If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header @@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke return nil, err.Error(), http.StatusInternalServerError } - // Check scopes. + // Check scopes. All must be present. if opts != nil { // Note: quadratic, but N is small. for _, s := range opts.Scopes { diff --git a/internal/jsonrpc2/messages.go b/internal/jsonrpc2/messages.go index 2de3d4f0..9c0d5d69 100644 --- a/internal/jsonrpc2/messages.go +++ b/internal/jsonrpc2/messages.go @@ -56,6 +56,9 @@ type Request struct { Method string // Params is either a struct or an array with the parameters of the method. Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any } // Response is a Message used as a reply to a call Request. @@ -67,6 +70,9 @@ type Response struct { Error error // id of the request this is a response to. ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any } // StringID creates a new string request identifier. diff --git a/mcp/shared.go b/mcp/shared.go index 608e2aaf..bda631fe 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -126,7 +127,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ } mh := session.receivingMethodHandler() - req := info.newRequest(session, params) + re, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, re) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) if err != nil { @@ -173,7 +175,7 @@ type methodInfo struct { // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) - newRequest func(Session, Params) Request + newRequest func(Session, Params, *RequestExtra) Request // Run the code when a call to the method is received. // Used on the receive side. handleMethod MethodHandler @@ -208,7 +210,7 @@ const ( func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params) Request { + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { r := &ClientRequest[P]{Session: s.(*ClientSession)} if p != nil { r.Params = p.(P) @@ -223,19 +225,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params) Request { - r := &ServerRequest[P]{Session: s.(*ServerSession)} + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} if p != nil { r.Params = p.(P) } return r } mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { - rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)} - if req.GetParams() != nil { - rf.Params = req.GetParams().(P) - } - return d(ctx, rf) + return d(ctx, req.(*ServerRequest[P])) }) return mi } @@ -391,6 +389,13 @@ type ClientRequest[P Params] struct { type ServerRequest[P Params] struct { Session *ServerSession Params P + Extra *RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. +type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/streamable.go b/mcp/streamable.go index 572fe5de..c51f3cc4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -579,12 +580,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // This also requires access to the negotiated version, which would either be // set by the MCP-Protocol-Version header, or would require peeking into the // session. + if err != nil { + http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) + return + } incoming, _, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } requests := make(map[jsonrpc.ID]struct{}) + tokenInfo := auth.TokenInfoFromContext(req.Context()) for _, msg := range incoming { if req, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -594,6 +600,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } + req.Extra = &RequestExtra{TokenInfo: tokenInfo} if req.IsCall() { requests[req.ID] = struct{}{} } @@ -1182,6 +1189,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth = false + func (c *streamableClientConn) setMCPHeaders(req *http.Request) { c.mu.Lock() defer c.mu.Unlock() @@ -1192,6 +1203,9 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } + if testAuth { + req.Header.Set("Authorization", "Bearer foo") + } } func (c *streamableClientConn) handleJSON(resp *http.Response) { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8334bc0d..93eafb4a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -1098,3 +1099,51 @@ func textContent(t *testing.T, res *CallToolResult) string { } return text.Text } + +func TestTokenInfo(t *testing.T) { + defer func(b bool) { testAuth = b }(testAuth) + testAuth = true + ctx := context.Background() + + // Create a server with a tool that returns TokenInfo. + tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) { + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) + + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + verifier := func(context.Context, string) (*auth.TokenInfo, error) { + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + // Expiration is far, far in the future. + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, nil) + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) + if err != nil { + t.Fatal(err) + } + if len(res.Content) == 0 { + t.Fatal("missing content") + } + tc, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatal("not TextContent") + } + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/mcp/tool.go b/mcp/tool.go index 15f17e11..7173b8a8 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -68,6 +68,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ Session: req.Session, Params: params, + Extra: req.Extra, }) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. diff --git a/mcp/transport.go b/mcp/transport.go index 8018910b..2bcd8d7d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -86,8 +86,7 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. -type StdioTransport struct { -} +type StdioTransport struct{} // Connect implements the [Transport] interface. func (*StdioTransport) Connect(context.Context) (Connection, error) { From 719f68d0bf97de8d3a802da0bf237cf71b99da04 Mon Sep 17 00:00:00 2001 From: Shusaku Yasoda <136243871+yasomaru@users.noreply.github.com> Date: Thu, 21 Aug 2025 01:31:11 +0900 Subject: [PATCH 114/221] improve error handling for tools and server methods (#311) - Fix tool error handling to properly distinguish between protocol errors and tool execution errors - Return structured JSON-RPC errors directly for protocol-level issues - Embed regular tool errors in CallToolResult as per MCP specification - Improve server error responses for unknown prompts and tools - Add comprehensive tests for both structured and regular error handling --- mcp/server.go | 12 +++++-- mcp/shared.go | 2 ++ mcp/tool.go | 12 +++++-- mcp/tool_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 5 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 88021336..65115afa 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -328,8 +328,11 @@ func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptPar prompt, ok := s.prompts.get(req.Params.Name) s.mu.Unlock() if !ok { - // TODO: surface the error code over the wire, instead of flattening it into the string. - return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, req.Params.Name) + // Return a proper JSON-RPC error with the correct error code + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), + } } return prompt.handler(ctx, req.Session, req.Params) } @@ -353,7 +356,10 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() if !ok { - return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name) + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } } // TODO: if handler returns nil content, it will serialize as null. // Add a test and fix. diff --git a/mcp/shared.go b/mcp/shared.go index bda631fe..6aafcc77 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -311,6 +311,8 @@ const ( CodeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. CodeUnsupportedMethod = -31001 + // The error code for invalid parameters + CodeInvalidParams = -32602 ) // notifySessions calls Notify on all the sessions. diff --git a/mcp/tool.go b/mcp/tool.go index 7173b8a8..893b48ff 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -12,6 +12,7 @@ import ( "reflect" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) // A ToolHandler handles a call to tools/call. @@ -70,9 +71,16 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool Params: params, Extra: req.Extra, }) - // TODO(rfindley): investigate why server errors are embedded in this strange way, - // rather than returned as jsonrpc2 server errors. + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc2.WireError); ok { + return nil, wireErr + } + // For regular errors, embed them in the tool result as per MCP spec return &CallToolResult{ Content: []Content{&TextContent{Text: err.Error()}}, IsError: true, diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 609536cc..4c73ec63 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -7,12 +7,16 @@ package mcp import ( "context" "encoding/json" + "errors" + "fmt" "reflect" + "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) // testToolHandler is used for type inference in TestNewServerTool. @@ -132,3 +136,91 @@ func TestUnmarshalSchema(t *testing.T) { } } + +func TestToolErrorHandling(t *testing.T) { + // Construct server and add both tools at the top level + server := NewServer(testImpl, nil) + + // Create a tool that returns a structured error + structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) { + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: "internal server error", + } + } + + // Create a tool that returns a regular error + regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) { + return nil, fmt.Errorf("tool execution failed") + } + + AddTool(server, &Tool{Name: "error_tool", Description: "returns structured error"}, structuredErrorHandler) + AddTool(server, &Tool{Name: "regular_error_tool", Description: "returns regular error"}, regularErrorHandler) + + // Connect server and client once + ct, st := NewInMemoryTransports() + _, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatal(err) + } + + client := NewClient(testImpl, nil) + cs, err := client.Connect(context.Background(), ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test that structured JSON-RPC errors are returned directly + t.Run("structured_error", func(t *testing.T) { + // Call the tool + _, err = cs.CallTool(context.Background(), &CallToolParams{ + Name: "error_tool", + Arguments: map[string]any{}, + }) + + // Should get the structured error directly + if err == nil { + t.Fatal("expected error, got nil") + } + + var wireErr *jsonrpc2.WireError + if !errors.As(err, &wireErr) { + t.Fatalf("expected WireError, got %[1]T: %[1]v", err) + } + + if wireErr.Code != CodeInvalidParams { + t.Errorf("expected error code %d, got %d", CodeInvalidParams, wireErr.Code) + } + }) + + // Test that regular errors are embedded in tool results + t.Run("regular_error", func(t *testing.T) { + // Call the tool + result, err := cs.CallTool(context.Background(), &CallToolParams{ + Name: "regular_error_tool", + Arguments: map[string]any{}, + }) + + // Should not get an error at the protocol level + if err != nil { + t.Fatalf("unexpected protocol error: %v", err) + } + + // Should get a result with IsError=true + if !result.IsError { + t.Error("expected IsError=true, got false") + } + + // Should have error message in content + if len(result.Content) == 0 { + t.Error("expected error content, got empty") + } + + if textContent, ok := result.Content[0].(*TextContent); !ok { + t.Error("expected TextContent") + } else if !strings.Contains(textContent.Text, "tool execution failed") { + t.Errorf("expected error message in content, got: %s", textContent.Text) + } + }) +} From 62d815922ffd902b0a50b4a0107610a629ce8542 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 20 Aug 2025 12:55:19 -0400 Subject: [PATCH 115/221] mcp: minor cleanup of some TODOs (#332) Do a pass through TODOs to delete or address minor TODOs. --- mcp/client.go | 1 - mcp/features.go | 4 +++- mcp/server.go | 13 +++++++------ mcp/shared.go | 6 ++++-- mcp/streamable.go | 1 - 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index d1d17502..fdf05417 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -226,7 +226,6 @@ func (c *Client) AddRoots(roots ...*Root) { // RemoveRoots removes the roots with the given URIs, // and notifies any connected servers if the list has changed. // It is not an error to remove a nonexistent root. -// TODO: notification func (c *Client) RemoveRoots(uris ...string) { changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, func() bool { return c.roots.remove(uris...) }) diff --git a/mcp/features.go b/mcp/features.go index 43c58854..438370fe 100644 --- a/mcp/features.go +++ b/mcp/features.go @@ -17,7 +17,9 @@ import ( // A featureSet is a collection of features of type T. // Every feature has a unique ID, and the spec never mentions // an ordering for the List calls, so what it calls a "list" is actually a set. -// TODO: switch to an ordered map +// +// An alternative implementation would use an ordered map, but that's probably +// not necessary as adds and removes are rare, and usually batched. type featureSet[T any] struct { uniqueID func(T) string features map[string]T diff --git a/mcp/server.go b/mcp/server.go index 65115afa..2a8ab93b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -89,17 +89,18 @@ type ServerOptions struct { // The first argument must not be nil. // // If non-nil, the provided options are used to configure the server. -func NewServer(impl *Implementation, opts *ServerOptions) *Server { +func NewServer(impl *Implementation, options *ServerOptions) *Server { if impl == nil { panic("nil Implementation") } - if opts == nil { - opts = new(ServerOptions) + var opts ServerOptions + if options != nil { + opts = *options } + options = nil // prevent reuse if opts.PageSize < 0 { panic(fmt.Errorf("invalid page size %d", opts.PageSize)) } - // TODO(jba): don't modify opts, modify Server.opts. if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } @@ -111,7 +112,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { } return &Server{ impl: impl, - opts: *opts, + opts: opts, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), @@ -463,7 +464,7 @@ func fileResourceHandler(dir string) ResourceHandler { return func(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (_ *ReadResourceResult, err error) { defer util.Wrapf(&err, "reading resource %s", req.Params.URI) - // TODO: use a memoizing API here. + // TODO(#25): use a memoizing API here. rootRes, err := req.Session.ListRoots(ctx, nil) if err != nil { return nil, fmt.Errorf("listing roots: %w", err) diff --git a/mcp/shared.go b/mcp/shared.go index 6aafcc77..c2162fb6 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -4,8 +4,10 @@ // This file contains code shared between client and server, including // method handler and middleware definitions. -// TODO: much of this is here so that we can factor out commonalities using -// generics. Perhaps it can be simplified with reflection. +// +// Much of this is here so that we can factor out commonalities using +// generics. If this becomes unwieldy, it can perhaps be simplified with +// reflection. package mcp diff --git a/mcp/streamable.go b/mcp/streamable.go index c51f3cc4..89ff45c2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1152,7 +1152,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - // TODO: do a best effort read of the body here, and format it in the error. resp.Body.Close() return fmt.Errorf("broken session: %v", resp.Status) } From 48abccbf6b34ab82b3b7dca5843f4962fd532da5 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:39:17 -0400 Subject: [PATCH 116/221] mcp/server: expose InitializeParams to ServerSession (#336) This CL enables the server session to see what capabilities the client has by introducing the InitializeParams() function. This CL also adds a test to ensure ClientCapabilities is accurate. Fixes #141 --- mcp/client.go | 17 ++++++++++------- mcp/client_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 2 ++ 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index fdf05417..2511c05b 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -105,6 +105,15 @@ func (e unsupportedProtocolVersionError) Error() string { // ClientSessionOptions is reserved for future use. type ClientSessionOptions struct{} +func (c *Client) capabilities() *ClientCapabilities { + caps := &ClientCapabilities{} + caps.Roots.ListChanged = true + if c.opts.CreateMessageHandler != nil { + caps.Sampling = &SamplingCapabilities{} + } + return caps +} + // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. // @@ -118,16 +127,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio return nil, err } - caps := &ClientCapabilities{} - caps.Roots.ListChanged = true - if c.opts.CreateMessageHandler != nil { - caps.Sampling = &SamplingCapabilities{} - } - params := &InitializeParams{ ProtocolVersion: latestProtocolVersion, ClientInfo: c.impl, - Capabilities: caps, + Capabilities: c.capabilities(), } req := &ClientRequest[*InitializeParams]{Session: cs, Params: params} res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) diff --git a/mcp/client_test.go b/mcp/client_test.go index 7920c55c..469fa3fb 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -190,3 +190,48 @@ func TestClientPaginateVariousPageSizes(t *testing.T) { }) } } + +func TestClientCapabilities(t *testing.T) { + testCases := []struct { + name string + configureClient func(s *Client) + clientOpts ClientOptions + wantCapabilities *ClientCapabilities + }{ + { + name: "With initial capabilities", + configureClient: func(s *Client) {}, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + }, + }, + { + name: "With sampling", + configureClient: func(s *Client) {}, + clientOpts: ClientOptions{ + CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + Sampling: &SamplingCapabilities{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := NewClient(testImpl, &tc.clientOpts) + tc.configureClient(client) + gotCapabilities := client.capabilities() + if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { + t.Errorf("capabilities() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/mcp/server.go b/mcp/server.go index 2a8ab93b..d98ff8ab 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -849,6 +849,8 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return handleReceive(ctx, ss, req) } +func (ss *ServerSession) InitializeParams() *InitializeParams { return ss.state.InitializeParams } + func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) From 73b8a7f134b44d522dba1d20e3b4376bd58baa0c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 20 Aug 2025 14:58:44 -0400 Subject: [PATCH 117/221] mcp: make request headers available to tools (#333) Add HTTP request headers to RequestExtra, so tools and other user-defined handlers can access them. Fixes #331. --- mcp/shared.go | 2 ++ mcp/streamable.go | 17 ++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/mcp/shared.go b/mcp/shared.go index c2162fb6..b7b8bda1 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -16,6 +16,7 @@ import ( "encoding/json" "fmt" "log" + "net/http" "reflect" "slices" "strings" @@ -400,6 +401,7 @@ type ServerRequest[P Params] struct { // the transport layer. type RequestExtra struct { TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any + Header http.Header // header from HTTP request, if any } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/streamable.go b/mcp/streamable.go index 89ff45c2..eb42ca57 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -580,10 +580,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // This also requires access to the negotiated version, which would either be // set by the MCP-Protocol-Version header, or would require peeking into the // session. - if err != nil { - http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) - return - } incoming, _, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) @@ -592,17 +588,20 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques requests := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) for _, msg := range incoming { - if req, ok := msg.(*jsonrpc.Request); ok { + if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. - if _, err := checkRequest(req, serverMethodInfos); err != nil { + if _, err := checkRequest(jreq, serverMethodInfos); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - req.Extra = &RequestExtra{TokenInfo: tokenInfo} - if req.IsCall() { - requests[req.ID] = struct{}{} + jreq.Extra = &RequestExtra{ + TokenInfo: tokenInfo, + Header: req.Header, + } + if jreq.IsCall() { + requests[jreq.ID] = struct{}{} } } } From 3f10c19b45f6ccbe22064fd0d66e1bd4b9a11937 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang <58842577+h9jiang@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:01:44 -0400 Subject: [PATCH 118/221] mcp: replace type struct{} with any for user-defined fields (#334) This change adopts a clearer convention: - any is used for fields where the structure is defined by the client or server implementer. - an named empty struct type is the place holder type to be defined by the MCP spec in the future. The "Experimental" fields in "ClientCapabilities" and "ServerCapabilities" now use "map[string]any". This allows clients and servers to negotiate custom capabilities that contain complex data, rather than being restricted to a simple on/off flag. The "Metadata" field in "CreateMessageParams" was also changed to "any" to properly serve its purpose as a flexible container for server-defined data. --- mcp/protocol.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcp/protocol.go b/mcp/protocol.go index 0125cb13..666c1bc7 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -126,7 +126,7 @@ func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } // additional capabilities. type ClientCapabilities struct { // Experimental, non-standard capabilities that the client supports. - Experimental map[string]struct{} `json:"experimental,omitempty"` + Experimental map[string]any `json:"experimental,omitempty"` // Present if the client supports listing roots. Roots struct { // Whether the client supports notifications for changes to the roots list. @@ -242,7 +242,7 @@ type CreateMessageParams struct { Messages []*SamplingMessage `json:"messages"` // Optional metadata to pass through to the LLM provider. The format of this // metadata is provider-specific. - Metadata struct{} `json:"metadata,omitempty"` + Metadata any `json:"metadata,omitempty"` // The server's preferences for which model to select. The client may ignore // these preferences. ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` @@ -997,7 +997,7 @@ type ServerCapabilities struct { // Present if the server supports argument autocompletion suggestions. Completions *CompletionCapabilities `json:"completions,omitempty"` // Experimental, non-standard capabilities that the server supports. - Experimental map[string]struct{} `json:"experimental,omitempty"` + Experimental map[string]any `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. Logging *LoggingCapabilities `json:"logging,omitempty"` // Present if the server offers any prompt templates. From 79f063bc98af26fdc8fa98b7e28b16fdc01ffc41 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 20 Aug 2025 14:54:21 +0000 Subject: [PATCH 119/221] mcp: handle the Mcp-Protocol-Version header correctly Handle the protocol version header according to section 2.7 of the spec (and the other SDKs). Fixes #198 --- mcp/server.go | 9 +-- mcp/shared.go | 34 +++++++-- mcp/streamable.go | 53 ++++++++++++-- mcp/streamable_test.go | 161 +++++++++++++++++++++++++++++------------ 4 files changed, 192 insertions(+), 65 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index d98ff8ab..71257ccf 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -859,18 +859,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam state.InitializeParams = params }) - // If we support the client's version, reply with it. Otherwise, reply with our - // latest version. - version := params.ProtocolVersion - if !slices.Contains(supportedProtocolVersions, params.ProtocolVersion) { - version = latestProtocolVersion - } - s := ss.server return &InitializeResult{ // TODO(rfindley): alter behavior when falling back to an older version: // reject unsupported features. - ProtocolVersion: version, + ProtocolVersion: negotiatedVersion(params.ProtocolVersion), Capabilities: s.capabilities(), Instructions: s.opts.Instructions, ServerInfo: s.impl, diff --git a/mcp/shared.go b/mcp/shared.go index b7b8bda1..0675ca45 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -27,14 +27,36 @@ import ( "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) -// latestProtocolVersion is the latest protocol version that this version of the SDK supports. -// It is the version that the client sends in the initialization request. -const latestProtocolVersion = "2025-06-18" +const ( + // latestProtocolVersion is the latest protocol version that this version of + // the SDK supports. + // + // It is the version that the client sends in the initialization request, and + // the default version used by the server. + latestProtocolVersion = protocolVersion20250618 + protocolVersion20250618 = "2025-06-18" + protocolVersion20250326 = "2025-03-26" + protocolVersion20251105 = "2024-11-05" +) var supportedProtocolVersions = []string{ - latestProtocolVersion, - "2025-03-26", - "2024-11-05", + protocolVersion20250618, + protocolVersion20250326, + protocolVersion20251105, +} + +// negotiatedVersion returns the effective protocol version to use, given a +// client version. +func negotiatedVersion(clientVersion string) string { + // In general, prefer to use the clientVersion, but if we don't support the + // client's version, use the latest version. + // + // This handles the case where a new spec version is released, and the SDK + // does not support it yet. + if !slices.Contains(supportedProtocolVersions, clientVersion) { + return latestProtocolVersion + } + return clientVersion } // A MethodHandler handles MCP messages. diff --git a/mcp/streamable.go b/mcp/streamable.go index eb42ca57..99fbe422 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -15,6 +15,7 @@ import ( "math" "math/rand/v2" "net/http" + "slices" "strconv" "strings" "sync" @@ -153,7 +154,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque if req.Method == http.MethodDelete { if sessionID == "" { - http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) + http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } if transport != nil { // transport may be nil in stateless mode @@ -173,8 +174,45 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } default: - w.Header().Set("Allow", "GET, POST") - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + w.Header().Set("Allow", "GET, POST, DELETE") + http.Error(w, "Method Not Allowed: streamable MCP servers support GET, POST, and DELETE requests", http.StatusMethodNotAllowed) + return + } + + // Section 2.7 of the spec (2025-06-18) states: + // + // "If using HTTP, the client MUST include the MCP-Protocol-Version: + // HTTP header on all subsequent requests to the MCP + // server, allowing the MCP server to respond based on the MCP protocol + // version. + // + // For example: MCP-Protocol-Version: 2025-06-18 + // The protocol version sent by the client SHOULD be the one negotiated during + // initialization. + // + // For backwards compatibility, if the server does not receive an + // MCP-Protocol-Version header, and has no other way to identify the version - + // for example, by relying on the protocol version negotiated during + // initialization - the server SHOULD assume protocol version 2025-03-26. + // + // If the server receives a request with an invalid or unsupported + // MCP-Protocol-Version, it MUST respond with 400 Bad Request." + // + // Since this wasn't present in the 2025-03-26 version of the spec, this + // effectively means: + // 1. IF the client provides a version header, it must be a supported + // version. + // 2. In stateless mode, where we've lost the state of the initialize + // request, we assume that whatever the client tells us is the truth (or + // assume 2025-03-26 if the client doesn't say anything). + // + // This logic matches the typescript SDK. + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + if !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) return } @@ -235,7 +273,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // set the initial state to a default value. state := new(ServerSessionState) if !hasInitialize { - state.InitializeParams = new(InitializeParams) + state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion, + } } if !hasInitialized { state.InitializedParams = new(InitializedParams) @@ -378,11 +418,12 @@ type streamableServerConn struct { eventStore EventStore incoming chan jsonrpc.Message // messages from the client to the server - done chan struct{} - mu sync.Mutex + mu sync.Mutex // guards all fields below + // Sessions are closed exactly once. isDone bool + done chan struct{} // Sessions can have multiple logical connections (which we call streams), // corresponding to HTTP requests. Additionally, streams may be resumed by diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 93eafb4a..b050b35d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "fmt" "io" + "maps" "net" "net/http" "net/http/cookiejar" @@ -460,21 +461,21 @@ func TestStreamableServerTransport(t *testing.T) { // Test various accept headers. { method: "POST", - accept: []string{"text/plain", "application/*"}, + headers: http.Header{"Accept": {"text/plain", "application/*"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream wantSessionID: false, }, { method: "POST", - accept: []string{"text/event-stream"}, + headers: http.Header{"Accept": {"text/event-stream"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json wantSessionID: false, }, { method: "POST", - accept: []string{"text/plain", "*/*"}, + headers: http.Header{"Accept": {"text/plain", "*/*"}}, messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, @@ -482,7 +483,7 @@ func TestStreamableServerTransport(t *testing.T) { }, { method: "POST", - accept: []string{"text/*, application/*"}, + headers: http.Header{"Accept": {"text/*, application/*"}}, messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, @@ -490,6 +491,21 @@ func TestStreamableServerTransport(t *testing.T) { }, }, }, + { + name: "protocol version headers", + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + headers: http.Header{"mcp-protocol-version": {"2025-01-01"}}, // an invalid protocol version + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "2025-03-26", // a supported version + wantSessionID: false, // could be true, but shouldn't matter + }, + }, + }, { name: "tool notification", tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { @@ -730,7 +746,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream } }() - gotSessionID, gotStatusCode, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) + gotSessionID, gotStatusCode, gotBody, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) // Don't fail on cancelled requests: error (if any) is handled // elsewhere. @@ -746,7 +762,12 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream } wg.Wait() - if !request.ignoreResponse { + if request.wantBodyContaining != "" { + body := string(gotBody) + if !strings.Contains(body, request.wantBodyContaining) { + t.Errorf("body does not contain %q:\n%s", request.wantBodyContaining, body) + } + } else { transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff) @@ -794,14 +815,14 @@ type streamableRequest struct { // Request attributes method string // HTTP request method (required) - accept []string // if non-empty, the Accept header to use; otherwise the default header is used + headers http.Header // additional headers to set, overlaid on top of the default headers messages []jsonrpc.Message // messages to send - closeAfter int // if nonzero, close after receiving this many messages - wantStatusCode int // expected status code - ignoreResponse bool // if set, don't check the response messages - wantMessages []jsonrpc.Message // expected messages to receive - wantSessionID bool // whether or not a session ID is expected in the response + closeAfter int // if nonzero, close after receiving this many messages + wantStatusCode int // expected status code + wantBodyContaining string // if set, expect the response body to contain this text; overrides wantMessages + wantMessages []jsonrpc.Message // expected messages to receive; ignored if wantBodyContaining is set + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -817,14 +838,14 @@ type streamableRequest struct { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, error) { +func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, []byte, error) { defer close(out) var body []byte if len(s.messages) == 1 { data, err := jsonrpc2.EncodeMessage(s.messages[0]) if err != nil { - return "", 0, fmt.Errorf("encoding message: %w", err) + return "", 0, nil, fmt.Errorf("encoding message: %w", err) } body = data } else { @@ -832,68 +853,93 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, for _, msg := range s.messages { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { - return "", 0, fmt.Errorf("encoding message: %w", err) + return "", 0, nil, fmt.Errorf("encoding message: %w", err) } rawMsgs = append(rawMsgs, data) } data, err := json.Marshal(rawMsgs) if err != nil { - return "", 0, fmt.Errorf("marshaling batch: %w", err) + return "", 0, nil, fmt.Errorf("marshaling batch: %w", err) } body = data } req, err := http.NewRequestWithContext(ctx, s.method, serverURL, bytes.NewReader(body)) if err != nil { - return "", 0, fmt.Errorf("creating request: %w", err) + return "", 0, nil, fmt.Errorf("creating request: %w", err) } if sessionID != "" { req.Header.Set("Mcp-Session-Id", sessionID) } req.Header.Set("Content-Type", "application/json") - if len(s.accept) > 0 { - for _, accept := range s.accept { - req.Header.Add("Accept", accept) - } - } else { - req.Header.Add("Accept", "application/json, text/event-stream") - } + req.Header.Set("Accept", "application/json, text/event-stream") + maps.Copy(req.Header, s.headers) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", 0, fmt.Errorf("request failed: %v", err) + return "", 0, nil, fmt.Errorf("request failed: %v", err) } defer resp.Body.Close() newSessionID := resp.Header.Get("Mcp-Session-Id") contentType := resp.Header.Get("Content-Type") + var respBody []byte if strings.HasPrefix(contentType, "text/event-stream") { - for evt, err := range scanEvents(resp.Body) { + r := readerInto{resp.Body, new(bytes.Buffer)} + for evt, err := range scanEvents(r) { if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("reading events: %v", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err) } // TODO(rfindley): do we need to check evt.name? // Does the MCP spec say anything about this? msg, err := jsonrpc2.DecodeMessage(evt.Data) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) } out <- msg } + respBody = r.w.Bytes() } else if strings.HasPrefix(contentType, "application/json") { data, err := io.ReadAll(resp.Body) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("reading json body: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading json body: %w", err) } + respBody = data msg, err := jsonrpc2.DecodeMessage(data) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) } out <- msg + } else { + respBody, err = io.ReadAll(resp.Body) + if err != nil { + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading response: %v", err) + } } - return newSessionID, resp.StatusCode, nil + return newSessionID, resp.StatusCode, respBody, nil +} + +// readerInto is an io.Reader that writes any bytes read from r into w. +type readerInto struct { + r io.Reader + w *bytes.Buffer +} + +// Read implements io.Reader. +func (r readerInto) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil || err == io.EOF { + n2, err2 := r.w.Write(p[:n]) + if err2 != nil { + return n, fmt.Errorf("failed to write: %v", err) + } + if n2 != n { + return n, fmt.Errorf("short write: %d != %d", n2, n) + } + } + return n, err } func mustMarshal(v any) json.RawMessage { @@ -907,8 +953,13 @@ func mustMarshal(v any) json.RawMessage { return data } -func TestStreamableClientTransportApplicationJSON(t *testing.T) { - // Test handling of application/json responses. +func TestStreamableClientTransport(t *testing.T) { + // This test verifies various behavior of the streamable client transport: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + // + // TODO(rfindley): make this test more comprehensive, similar to + // [TestStreamableServerTransport]. ctx := context.Background() resp := func(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ @@ -928,14 +979,25 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { } initResp := resp(1, initResult, nil) + var reqN atomic.Int32 // request count serverHandler := func(w http.ResponseWriter, r *http.Request) { - data, err := jsonrpc2.EncodeMessage(initResp) - if err != nil { - t.Fatal(err) - } + rN := reqN.Add(1) + + // TODO(rfindley): if the status code is NoContent or Accepted, we should + // probably be tolerant of when the content type is not application/json. w.Header().Set("Content-Type", "application/json") - w.Header().Set("Mcp-Session-Id", "123") - w.Write(data) + if rN == 1 { + data, err := jsonrpc2.EncodeMessage(initResp) + if err != nil { + t.Errorf("encoding failed: %v", err) + } + w.Header().Set("Mcp-Session-Id", "123") + w.Write(data) + } else { + if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion { + t.Errorf("bad protocol version header: got %q, want %q", v, latestProtocolVersion) + } + } } httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) @@ -947,7 +1009,16 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer session.Close() + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + + if got, want := reqN.Load(), int32(3); got < want { + // Expect at least 3 requests: initialize, initialized, and DELETE. + // We may or may not observe the GET, depending on timing. + t.Errorf("unexpected number of requests: got %d, want at least %d", got, want) + } + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } @@ -1011,11 +1082,11 @@ func TestStreamableStateless(t *testing.T) { requests := []streamableRequest{ { - method: "POST", - wantStatusCode: http.StatusOK, - messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, - ignoreResponse: true, - wantSessionID: false, + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, + wantBodyContaining: "greet", + wantSessionID: false, }, { method: "POST", From 32cf71de3adc32735808af3d6633b1f07e908bd8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 20 Aug 2025 19:36:34 -0400 Subject: [PATCH 120/221] mcp: change tool handler design (#325) Change the design of tool handlers by removing genericity from the common path. - `ToolHandler` gets the args as a `json.RawMessage`. Using Server.AddTool does no unmarshaling or schema validation. - `ToolHandlerFor`, installed with `AddTool`, are the only generic pieces. TODO: uncomment and fix some tests --- README.md | 8 +- examples/server/custom-transport/main.go | 8 +- examples/server/hello/main.go | 8 +- examples/server/memory/kb.go | 99 ++-- examples/server/memory/kb_test.go | 436 +++++++++-------- examples/server/sequentialthinking/main.go | 40 +- .../server/sequentialthinking/main_test.go | 197 +++----- examples/server/sse/main.go | 8 +- internal/readme/server/server.go | 8 +- mcp/client_list_test.go | 11 +- mcp/content_nil_test.go | 4 +- mcp/example_middleware_test.go | 14 +- mcp/features_test.go | 9 - mcp/mcp_test.go | 36 +- mcp/protocol.go | 42 +- mcp/protocol_test.go | 7 +- mcp/server.go | 135 ++++-- mcp/server_example_test.go | 8 +- mcp/shared_test.go | 438 +++++++++--------- mcp/sse_example_test.go | 8 +- mcp/streamable_test.go | 48 +- mcp/tool.go | 93 +--- mcp/tool_test.go | 90 +--- 23 files changed, 789 insertions(+), 966 deletions(-) diff --git a/README.md b/README.md index 4700d087..b46724b7 100644 --- a/README.md +++ b/README.md @@ -115,10 +115,10 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}}, - }, nil +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, + }, nil, nil } func main() { diff --git a/examples/server/custom-transport/main.go b/examples/server/custom-transport/main.go index bf0306cf..72cfc31d 100644 --- a/examples/server/custom-transport/main.go +++ b/examples/server/custom-transport/main.go @@ -85,12 +85,12 @@ type HiArgs struct { } // SayHi is a tool handler that responds with a greeting. -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - return &mcp.CallToolResultFor[struct{}]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, struct{}, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, struct{}{}, nil } func main() { diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index 8125441b..f71b0a78 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -22,12 +22,12 @@ type HiArgs struct { Name string `json:"name" jsonschema:"the name to say hi to"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - return &mcp.CallToolResultFor[struct{}]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, nil, nil } func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index f053bee5..2277c22b 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -431,12 +431,12 @@ func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { }, nil } -func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[CreateEntitiesArgs]]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { - var res mcp.CallToolResultFor[CreateEntitiesResult] +func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { + var res mcp.CallToolResult - entities, err := k.createEntities(req.Params.Arguments.Entities) + entities, err := k.createEntities(args.Entities) if err != nil { - return nil, err + return nil, CreateEntitiesResult{}, err } res.Content = []mcp.Content{ @@ -447,114 +447,107 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerReques Entities: entities, } - return &res, nil + return &res, CreateEntitiesResult{Entities: entities}, nil } -func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[CreateRelationsArgs]]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { - var res mcp.CallToolResultFor[CreateRelationsResult] +func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { + var res mcp.CallToolResult - relations, err := k.createRelations(req.Params.Arguments.Relations) + relations, err := k.createRelations(args.Relations) if err != nil { - return nil, err + return nil, CreateRelationsResult{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Relations created successfully"}, } - res.StructuredContent = CreateRelationsResult{ - Relations: relations, - } - - return &res, nil + return &res, CreateRelationsResult{Relations: relations}, nil } -func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddObservationsArgs]]) (*mcp.CallToolResultFor[AddObservationsResult], error) { - var res mcp.CallToolResultFor[AddObservationsResult] +func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { + var res mcp.CallToolResult - observations, err := k.addObservations(req.Params.Arguments.Observations) + observations, err := k.addObservations(args.Observations) if err != nil { - return nil, err + return nil, AddObservationsResult{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Observations added successfully"}, } - res.StructuredContent = AddObservationsResult{ + return &res, AddObservationsResult{ Observations: observations, - } - - return &res, nil + }, nil } -func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteEntitiesArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - var res mcp.CallToolResultFor[struct{}] +func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteEntitiesArgs) (*mcp.CallToolResult, struct{}, error) { + var res mcp.CallToolResult - err := k.deleteEntities(req.Params.Arguments.EntityNames) + err := k.deleteEntities(args.EntityNames) if err != nil { - return nil, err + return nil, struct{}{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Entities deleted successfully"}, } - return &res, nil + return &res, struct{}{}, nil } -func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteObservationsArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - var res mcp.CallToolResultFor[struct{}] +func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteObservationsArgs) (*mcp.CallToolResult, struct{}, error) { + var res mcp.CallToolResult - err := k.deleteObservations(req.Params.Arguments.Deletions) + err := k.deleteObservations(args.Deletions) if err != nil { - return nil, err + return nil, struct{}{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Observations deleted successfully"}, } - return &res, nil + return &res, struct{}{}, nil } -func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteRelationsArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - var res mcp.CallToolResultFor[struct{}] +func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { + var res mcp.CallToolResult - err := k.deleteRelations(req.Params.Arguments.Relations) + err := k.deleteRelations(args.Relations) if err != nil { - return nil, err + return nil, struct{}{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Relations deleted successfully"}, } - return &res, nil + return &res, struct{}{}, nil } -func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[struct{}]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { - var res mcp.CallToolResultFor[KnowledgeGraph] +func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args any) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult graph, err := k.loadGraph() if err != nil { - return nil, err + return nil, KnowledgeGraph{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Graph read successfully"}, } - res.StructuredContent = graph - return &res, nil + return &res, graph, nil } -func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SearchNodesArgs]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { - var res mcp.CallToolResultFor[KnowledgeGraph] +func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult - graph, err := k.searchNodes(req.Params.Arguments.Query) + graph, err := k.searchNodes(args.Query) if err != nil { - return nil, err + return nil, KnowledgeGraph{}, err } res.Content = []mcp.Content{ @@ -562,21 +555,19 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[* } res.StructuredContent = graph - return &res, nil + return &res, KnowledgeGraph{}, nil } -func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[OpenNodesArgs]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { - var res mcp.CallToolResultFor[KnowledgeGraph] +func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult - graph, err := k.openNodes(req.Params.Arguments.Names) + graph, err := k.openNodes(args.Names) if err != nil { - return nil, err + return nil, KnowledgeGraph{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Nodes opened successfully"}, } - - res.StructuredContent = graph - return &res, nil + return &res, graph, nil } diff --git a/examples/server/memory/kb_test.go b/examples/server/memory/kb_test.go index 6e29d5e4..d0cf38c0 100644 --- a/examples/server/memory/kb_test.go +++ b/examples/server/memory/kb_test.go @@ -427,203 +427,203 @@ func TestFileFormatting(t *testing.T) { } // TestMCPServerIntegration tests the knowledge base through MCP server layer. -func TestMCPServerIntegration(t *testing.T) { - for name, newStore := range stores() { - t.Run(name, func(t *testing.T) { - s := newStore(t) - kb := knowledgeBase{s: s} - - // Create mock server session - ctx := context.Background() - serverSession := &mcp.ServerSession{} - - // Test CreateEntities through MCP - createEntitiesParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ - Arguments: CreateEntitiesArgs{ - Entities: []Entity{ - { - Name: "TestPerson", - EntityType: "Person", - Observations: []string{"Likes testing"}, - }, - }, - }, - } - - createResult, err := kb.CreateEntities(ctx, requestFor(serverSession, createEntitiesParams)) - if err != nil { - t.Fatalf("MCP CreateEntities failed: %v", err) - } - if createResult.IsError { - t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) - } - if len(createResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity created, got %d", len(createResult.StructuredContent.Entities)) - } - - // Test ReadGraph through MCP - readParams := &mcp.CallToolParamsFor[struct{}]{} - readResult, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) - if err != nil { - t.Fatalf("MCP ReadGraph failed: %v", err) - } - if readResult.IsError { - t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) - } - if len(readResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity in graph, got %d", len(readResult.StructuredContent.Entities)) - } - - // Test CreateRelations through MCP - createRelationsParams := &mcp.CallToolParamsFor[CreateRelationsArgs]{ - Arguments: CreateRelationsArgs{ - Relations: []Relation{ - { - From: "TestPerson", - To: "Testing", - RelationType: "likes", - }, - }, - }, - } - - relationsResult, err := kb.CreateRelations(ctx, requestFor(serverSession, createRelationsParams)) - if err != nil { - t.Fatalf("MCP CreateRelations failed: %v", err) - } - if relationsResult.IsError { - t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) - } - if len(relationsResult.StructuredContent.Relations) != 1 { - t.Errorf("expected 1 relation created, got %d", len(relationsResult.StructuredContent.Relations)) - } - - // Test AddObservations through MCP - addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ - Arguments: AddObservationsArgs{ - Observations: []Observation{ - { - EntityName: "TestPerson", - Contents: []string{"Works remotely", "Drinks coffee"}, - }, - }, - }, - } - - obsResult, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) - if err != nil { - t.Fatalf("MCP AddObservations failed: %v", err) - } - if obsResult.IsError { - t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) - } - if len(obsResult.StructuredContent.Observations) != 1 { - t.Errorf("expected 1 observation result, got %d", len(obsResult.StructuredContent.Observations)) - } - - // Test SearchNodes through MCP - searchParams := &mcp.CallToolParamsFor[SearchNodesArgs]{ - Arguments: SearchNodesArgs{ - Query: "coffee", - }, - } - - searchResult, err := kb.SearchNodes(ctx, requestFor(serverSession, searchParams)) - if err != nil { - t.Fatalf("MCP SearchNodes failed: %v", err) - } - if searchResult.IsError { - t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) - } - if len(searchResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity from search, got %d", len(searchResult.StructuredContent.Entities)) - } - - // Test OpenNodes through MCP - openParams := &mcp.CallToolParamsFor[OpenNodesArgs]{ - Arguments: OpenNodesArgs{ - Names: []string{"TestPerson"}, - }, - } - - openResult, err := kb.OpenNodes(ctx, requestFor(serverSession, openParams)) - if err != nil { - t.Fatalf("MCP OpenNodes failed: %v", err) - } - if openResult.IsError { - t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) - } - if len(openResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity from open, got %d", len(openResult.StructuredContent.Entities)) - } - - // Test DeleteObservations through MCP - deleteObsParams := &mcp.CallToolParamsFor[DeleteObservationsArgs]{ - Arguments: DeleteObservationsArgs{ - Deletions: []Observation{ - { - EntityName: "TestPerson", - Observations: []string{"Works remotely"}, - }, - }, - }, - } - - deleteObsResult, err := kb.DeleteObservations(ctx, requestFor(serverSession, deleteObsParams)) - if err != nil { - t.Fatalf("MCP DeleteObservations failed: %v", err) - } - if deleteObsResult.IsError { - t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) - } - - // Test DeleteRelations through MCP - deleteRelParams := &mcp.CallToolParamsFor[DeleteRelationsArgs]{ - Arguments: DeleteRelationsArgs{ - Relations: []Relation{ - { - From: "TestPerson", - To: "Testing", - RelationType: "likes", - }, - }, - }, - } - - deleteRelResult, err := kb.DeleteRelations(ctx, requestFor(serverSession, deleteRelParams)) - if err != nil { - t.Fatalf("MCP DeleteRelations failed: %v", err) - } - if deleteRelResult.IsError { - t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) - } - - // Test DeleteEntities through MCP - deleteEntParams := &mcp.CallToolParamsFor[DeleteEntitiesArgs]{ - Arguments: DeleteEntitiesArgs{ - EntityNames: []string{"TestPerson"}, - }, - } - - deleteEntResult, err := kb.DeleteEntities(ctx, requestFor(serverSession, deleteEntParams)) - if err != nil { - t.Fatalf("MCP DeleteEntities failed: %v", err) - } - if deleteEntResult.IsError { - t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) - } - - // Verify final state - finalRead, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) - if err != nil { - t.Fatalf("Final MCP ReadGraph failed: %v", err) - } - if len(finalRead.StructuredContent.Entities) != 0 { - t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.StructuredContent.Entities)) - } - }) - } -} +// func TestMCPServerIntegration(t *testing.T) { +// for name, newStore := range stores() { +// t.Run(name, func(t *testing.T) { +// s := newStore(t) +// kb := knowledgeBase{s: s} + +// // Create mock server session +// ctx := context.Background() +// serverSession := &mcp.ServerSession{} + +// // Test CreateEntities through MCP +// createEntitiesParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ +// Arguments: CreateEntitiesArgs{ +// Entities: []Entity{ +// { +// Name: "TestPerson", +// EntityType: "Person", +// Observations: []string{"Likes testing"}, +// }, +// }, +// }, +// } + +// createResult, err := kb.CreateEntities(ctx, requestFor(serverSession, createEntitiesParams)) +// if err != nil { +// t.Fatalf("MCP CreateEntities failed: %v", err) +// } +// if createResult.IsError { +// t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) +// } +// if len(createResult.StructuredContent.Entities) != 1 { +// t.Errorf("expected 1 entity created, got %d", len(createResult.StructuredContent.Entities)) +// } + +// // Test ReadGraph through MCP +// readParams := &mcp.CallToolParamsFor[struct{}]{} +// readResult, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) +// if err != nil { +// t.Fatalf("MCP ReadGraph failed: %v", err) +// } +// if readResult.IsError { +// t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) +// } +// if len(readResult.StructuredContent.Entities) != 1 { +// t.Errorf("expected 1 entity in graph, got %d", len(readResult.StructuredContent.Entities)) +// } + +// // Test CreateRelations through MCP +// createRelationsParams := &mcp.CallToolParamsFor[CreateRelationsArgs]{ +// Arguments: CreateRelationsArgs{ +// Relations: []Relation{ +// { +// From: "TestPerson", +// To: "Testing", +// RelationType: "likes", +// }, +// }, +// }, +// } + +// relationsResult, err := kb.CreateRelations(ctx, requestFor(serverSession, createRelationsParams)) +// if err != nil { +// t.Fatalf("MCP CreateRelations failed: %v", err) +// } +// if relationsResult.IsError { +// t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) +// } +// if len(relationsResult.StructuredContent.Relations) != 1 { +// t.Errorf("expected 1 relation created, got %d", len(relationsResult.StructuredContent.Relations)) +// } + +// // Test AddObservations through MCP +// addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ +// Arguments: AddObservationsArgs{ +// Observations: []Observation{ +// { +// EntityName: "TestPerson", +// Contents: []string{"Works remotely", "Drinks coffee"}, +// }, +// }, +// }, +// } + +// obsResult, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) +// if err != nil { +// t.Fatalf("MCP AddObservations failed: %v", err) +// } +// if obsResult.IsError { +// t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) +// } +// if len(obsResult.StructuredContent.Observations) != 1 { +// t.Errorf("expected 1 observation result, got %d", len(obsResult.StructuredContent.Observations)) +// } + +// // Test SearchNodes through MCP +// searchParams := &mcp.CallToolParamsFor[SearchNodesArgs]{ +// Arguments: SearchNodesArgs{ +// Query: "coffee", +// }, +// } + +// searchResult, err := kb.SearchNodes(ctx, requestFor(serverSession, searchParams)) +// if err != nil { +// t.Fatalf("MCP SearchNodes failed: %v", err) +// } +// if searchResult.IsError { +// t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) +// } +// if len(searchResult.StructuredContent.Entities) != 1 { +// t.Errorf("expected 1 entity from search, got %d", len(searchResult.StructuredContent.Entities)) +// } + +// // Test OpenNodes through MCP +// openParams := &mcp.CallToolParamsFor[OpenNodesArgs]{ +// Arguments: OpenNodesArgs{ +// Names: []string{"TestPerson"}, +// }, +// } + +// openResult, err := kb.OpenNodes(ctx, requestFor(serverSession, openParams)) +// if err != nil { +// t.Fatalf("MCP OpenNodes failed: %v", err) +// } +// if openResult.IsError { +// t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) +// } +// if len(openResult.StructuredContent.Entities) != 1 { +// t.Errorf("expected 1 entity from open, got %d", len(openResult.StructuredContent.Entities)) +// } + +// // Test DeleteObservations through MCP +// deleteObsParams := &mcp.CallToolParamsFor[DeleteObservationsArgs]{ +// Arguments: DeleteObservationsArgs{ +// Deletions: []Observation{ +// { +// EntityName: "TestPerson", +// Observations: []string{"Works remotely"}, +// }, +// }, +// }, +// } + +// deleteObsResult, err := kb.DeleteObservations(ctx, requestFor(serverSession, deleteObsParams)) +// if err != nil { +// t.Fatalf("MCP DeleteObservations failed: %v", err) +// } +// if deleteObsResult.IsError { +// t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) +// } + +// // Test DeleteRelations through MCP +// deleteRelParams := &mcp.CallToolParamsFor[DeleteRelationsArgs]{ +// Arguments: DeleteRelationsArgs{ +// Relations: []Relation{ +// { +// From: "TestPerson", +// To: "Testing", +// RelationType: "likes", +// }, +// }, +// }, +// } + +// deleteRelResult, err := kb.DeleteRelations(ctx, requestFor(serverSession, deleteRelParams)) +// if err != nil { +// t.Fatalf("MCP DeleteRelations failed: %v", err) +// } +// if deleteRelResult.IsError { +// t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) +// } + +// // Test DeleteEntities through MCP +// deleteEntParams := &mcp.CallToolParamsFor[DeleteEntitiesArgs]{ +// Arguments: DeleteEntitiesArgs{ +// EntityNames: []string{"TestPerson"}, +// }, +// } + +// deleteEntResult, err := kb.DeleteEntities(ctx, requestFor(serverSession, deleteEntParams)) +// if err != nil { +// t.Fatalf("MCP DeleteEntities failed: %v", err) +// } +// if deleteEntResult.IsError { +// t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) +// } + +// // Verify final state +// finalRead, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) +// if err != nil { +// t.Fatalf("Final MCP ReadGraph failed: %v", err) +// } +// if len(finalRead.StructuredContent.Entities) != 0 { +// t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.StructuredContent.Entities)) +// } +// }) +// } +// } // TestMCPErrorHandling tests error scenarios through MCP layer. func TestMCPErrorHandling(t *testing.T) { @@ -633,21 +633,15 @@ func TestMCPErrorHandling(t *testing.T) { kb := knowledgeBase{s: s} ctx := context.Background() - serverSession := &mcp.ServerSession{} - - // Test adding observations to non-existent entity - addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ - Arguments: AddObservationsArgs{ - Observations: []Observation{ - { - EntityName: "NonExistentEntity", - Contents: []string{"This should fail"}, - }, + + _, _, err := kb.AddObservations(ctx, nil, AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This should fail"}, }, }, - } - - _, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) + }) if err == nil { t.Errorf("expected MCP AddObservations to return error for non-existent entity") } else { @@ -667,18 +661,12 @@ func TestMCPResponseFormat(t *testing.T) { kb := knowledgeBase{s: s} ctx := context.Background() - serverSession := &mcp.ServerSession{} - - // Test CreateEntities response format - createParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ - Arguments: CreateEntitiesArgs{ - Entities: []Entity{ - {Name: "FormatTest", EntityType: "Test"}, - }, - }, - } - result, err := kb.CreateEntities(ctx, requestFor(serverSession, createParams)) + result, out, err := kb.CreateEntities(ctx, nil, CreateEntitiesArgs{ + Entities: []Entity{ + {Name: "FormatTest", EntityType: "Test"}, + }, + }) if err != nil { t.Fatalf("CreateEntities failed: %v", err) } @@ -687,7 +675,7 @@ func TestMCPResponseFormat(t *testing.T) { if len(result.Content) == 0 { t.Errorf("expected Content field to be populated") } - if len(result.StructuredContent.Entities) == 0 { + if len(out.Entities) == 0 { t.Errorf("expected StructuredContent.Entities to be populated") } @@ -701,7 +689,3 @@ func TestMCPResponseFormat(t *testing.T) { t.Errorf("expected Content[0] to be TextContent") } } - -func requestFor[P mcp.Params](ss *mcp.ServerSession, p P) *mcp.ServerRequest[P] { - return &mcp.ServerRequest[P]{Session: ss, Params: p} -} diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go index 45a4fa6f..100e1167 100644 --- a/examples/server/sequentialthinking/main.go +++ b/examples/server/sequentialthinking/main.go @@ -231,9 +231,7 @@ func deepCopyThoughts(thoughts []*Thought) []*Thought { } // StartThinking begins a new sequential thinking session for a complex problem. -func StartThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[StartThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { - args := req.Params.Arguments - +func StartThinking(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args StartThinkingArgs) (*mcp.CallToolResult, any, error) { sessionID := args.SessionID if sessionID == "" { sessionID = randText() @@ -255,20 +253,18 @@ func StartThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolPara store.SetSession(session) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Started thinking session '%s' for problem: %s\nEstimated steps: %d\nReady for your first thought.", sessionID, args.Problem, estimatedSteps), }, }, - }, nil + }, nil, nil } // ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. -func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[ContinueThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { - args := req.Params.Arguments - +func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { // Handle revision of existing thought if args.ReviseStep != nil { err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { @@ -283,17 +279,17 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP return session, nil }) if err != nil { - return nil, err + return nil, nil, err } - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Revised step %d in session '%s':\n%s", *args.ReviseStep, args.SessionID, args.Thought), }, }, - }, nil + }, nil, nil } // Handle branching @@ -322,20 +318,20 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP return session, nil }) if err != nil { - return nil, err + return nil, nil, err } // Save the branch session store.SetSession(branchSession) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Created branch '%s' from session '%s'. You can now continue thinking in either session.", branchID, args.SessionID), }, }, - }, nil + }, nil, nil } // Add new thought @@ -381,27 +377,25 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP return session, nil }) if err != nil { - return nil, err + return nil, nil, err } - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Session '%s' - %s:\n%s%s", args.SessionID, progress, args.Thought, statusMsg), }, }, - }, nil + }, nil, nil } // ReviewThinking provides a complete review of the thinking process for a session. -func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[ReviewThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { - args := req.Params.Arguments - +func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { // Get a snapshot of the session to avoid race conditions sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) if !exists { - return nil, fmt.Errorf("session %s not found", args.SessionID) + return nil, nil, fmt.Errorf("session %s not found", args.SessionID) } var review strings.Builder @@ -424,13 +418,13 @@ func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolPar fmt.Fprintf(&review, "%d. %s%s\n", i+1, thought.Content, status) } - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: review.String(), }, }, - }, nil + }, nil, nil } // ThinkingHistory handles resource requests for thinking session data and history. diff --git a/examples/server/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go index c5e4a95a..8889db7d 100644 --- a/examples/server/sequentialthinking/main_test.go +++ b/examples/server/sequentialthinking/main_test.go @@ -26,12 +26,7 @@ func TestStartThinking(t *testing.T) { EstimatedSteps: 5, } - params := &mcp.CallToolParamsFor[StartThinkingArgs]{ - Name: "start_thinking", - Arguments: args, - } - - result, err := StartThinking(ctx, requestFor(params)) + result, _, err := StartThinking(ctx, nil, args) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -84,12 +79,7 @@ func TestContinueThinking(t *testing.T) { EstimatedSteps: 3, } - startParams := &mcp.CallToolParamsFor[StartThinkingArgs]{ - Name: "start_thinking", - Arguments: startArgs, - } - - _, err := StartThinking(ctx, requestFor(startParams)) + _, _, err := StartThinking(ctx, nil, startArgs) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -100,12 +90,7 @@ func TestContinueThinking(t *testing.T) { Thought: "First thought: I need to understand the problem", } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, nil, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -153,12 +138,7 @@ func TestContinueThinkingWithCompletion(t *testing.T) { SessionID: "test_completion", } - startParams := &mcp.CallToolParamsFor[StartThinkingArgs]{ - Name: "start_thinking", - Arguments: startArgs, - } - - _, err := StartThinking(ctx, requestFor(startParams)) + _, _, err := StartThinking(ctx, nil, startArgs) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -171,12 +151,7 @@ func TestContinueThinkingWithCompletion(t *testing.T) { NextNeeded: &nextNeeded, } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, nil, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -228,12 +203,7 @@ func TestContinueThinkingRevision(t *testing.T) { ReviseStep: &reviseStep, } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, nil, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -259,72 +229,67 @@ func TestContinueThinkingRevision(t *testing.T) { } } -func TestContinueThinkingBranching(t *testing.T) { - // Setup session with existing thoughts - store = NewSessionStore() - session := &ThinkingSession{ - ID: "test_branch", - Problem: "Test problem", - Thoughts: []*Thought{ - {Index: 1, Content: "First thought", Created: time.Now()}, - }, - CurrentThought: 1, - EstimatedTotal: 3, - Status: "active", - Created: time.Now(), - LastActivity: time.Now(), - Branches: []string{}, - } - store.SetSession(session) - - ctx := context.Background() - continueArgs := ContinueThinkingArgs{ - SessionID: "test_branch", - Thought: "Alternative approach", - CreateBranch: true, - } - - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) - if err != nil { - t.Fatalf("ContinueThinking() error = %v", err) - } - - // Verify branch creation message - textContent, ok := result.Content[0].(*mcp.TextContent) - if !ok { - t.Fatal("Expected TextContent") - } - - if !strings.Contains(textContent.Text, "Created branch") { - t.Error("Result should indicate branch creation") - } - - // Verify branch was created - updatedSession, _ := store.Session("test_branch") - if len(updatedSession.Branches) != 1 { - t.Errorf("Expected 1 branch, got %d", len(updatedSession.Branches)) - } - - branchID := updatedSession.Branches[0] - if !strings.Contains(branchID, "test_branch_branch_") { - t.Error("Branch ID should contain parent session ID") - } - - // Verify branch session exists - branchSession, exists := store.Session(branchID) - if !exists { - t.Fatal("Branch session should exist") - } - - if len(branchSession.Thoughts) != 1 { - t.Error("Branch should inherit parent thoughts") - } -} +// func TestContinueThinkingBranching(t *testing.T) { +// // Setup session with existing thoughts +// store = NewSessionStore() +// session := &ThinkingSession{ +// ID: "test_branch", +// Problem: "Test problem", +// Thoughts: []*Thought{ +// {Index: 1, Content: "First thought", Created: time.Now()}, +// }, +// CurrentThought: 1, +// EstimatedTotal: 3, +// Status: "active", +// Created: time.Now(), +// LastActivity: time.Now(), +// Branches: []string{}, +// } +// store.SetSession(session) + +// ctx := context.Background() +// continueArgs := ContinueThinkingArgs{ +// SessionID: "test_branch", +// Thought: "Alternative approach", +// CreateBranch: true, +// } + +// continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ +// Name: "continue_thinking", +// Arguments: continueArgs, +// } + +// // Verify branch creation message +// textContent, ok := result.Content[0].(*mcp.TextContent) +// if !ok { +// t.Fatal("Expected TextContent") +// } + +// if !strings.Contains(textContent.Text, "Created branch") { +// t.Error("Result should indicate branch creation") +// } + +// // Verify branch was created +// updatedSession, _ := store.Session("test_branch") +// if len(updatedSession.Branches) != 1 { +// t.Errorf("Expected 1 branch, got %d", len(updatedSession.Branches)) +// } + +// branchID := updatedSession.Branches[0] +// if !strings.Contains(branchID, "test_branch_branch_") { +// t.Error("Branch ID should contain parent session ID") +// } + +// // Verify branch session exists +// branchSession, exists := store.Session(branchID) +// if !exists { +// t.Fatal("Branch session should exist") +// } + +// if len(branchSession.Thoughts) != 1 { +// t.Error("Branch should inherit parent thoughts") +// } +// } func TestReviewThinking(t *testing.T) { // Setup session with thoughts @@ -351,12 +316,7 @@ func TestReviewThinking(t *testing.T) { SessionID: "test_review", } - reviewParams := &mcp.CallToolParamsFor[ReviewThinkingArgs]{ - Name: "review_thinking", - Arguments: reviewArgs, - } - - result, err := ReviewThinking(ctx, requestFor(reviewParams)) + result, _, err := ReviewThinking(ctx, nil, reviewArgs) if err != nil { t.Fatalf("ReviewThinking() error = %v", err) } @@ -491,12 +451,7 @@ func TestInvalidOperations(t *testing.T) { Thought: "Some thought", } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - _, err := ContinueThinking(ctx, requestFor(continueParams)) + _, _, err := ContinueThinking(ctx, nil, continueArgs) if err == nil { t.Error("Expected error for non-existent session") } @@ -506,12 +461,7 @@ func TestInvalidOperations(t *testing.T) { SessionID: "nonexistent", } - reviewParams := &mcp.CallToolParamsFor[ReviewThinkingArgs]{ - Name: "review_thinking", - Arguments: reviewArgs, - } - - _, err = ReviewThinking(ctx, requestFor(reviewParams)) + _, _, err = ReviewThinking(ctx, nil, reviewArgs) if err == nil { t.Error("Expected error for non-existent session in review") } @@ -536,12 +486,7 @@ func TestInvalidOperations(t *testing.T) { ReviseStep: &reviseStep, } - invalidReviseParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: invalidReviseArgs, - } - - _, err = ContinueThinking(ctx, requestFor(invalidReviseParams)) + _, _, err = ContinueThinking(ctx, nil, invalidReviseArgs) if err == nil { t.Error("Expected error for invalid revision step") } diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index 2fbd695e..c2603b41 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -24,12 +24,12 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, nil, nil } func main() { diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 3aa1037c..087992e8 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -16,10 +16,10 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}}, - }, nil +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, + }, nil, nil } func main() { diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 836d4803..c1052c25 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -24,9 +24,14 @@ func TestList(t *testing.T) { t.Run("tools", func(t *testing.T) { var wantTools []*mcp.Tool for _, name := range []string{"apple", "banana", "cherry"} { - t := &mcp.Tool{Name: name, Description: name + " tool"} - wantTools = append(wantTools, t) - mcp.AddTool(server, t, SayHi) + tt := &mcp.Tool{Name: name, Description: name + " tool"} + mcp.AddTool(server, tt, SayHi) + is, err := jsonschema.For[SayHiParams](nil) + if err != nil { + t.Fatal(err) + } + tt.InputSchema = is + wantTools = append(wantTools, tt) } t.Run("list", func(t *testing.T) { res, err := clientSession.ListTools(ctx, nil) diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index c803ba69..70cabfd7 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -52,8 +52,8 @@ func TestContentUnmarshalNil(t *testing.T) { { name: "CallToolResultFor nil Content", json: `{"content":[{"type":"text","text":"hello"}]}`, - content: &mcp.CallToolResultFor[string]{}, - want: &mcp.CallToolResultFor[string]{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + content: &mcp.CallToolResult{}, + want: &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, }, } diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 56f7428a..0f6d540e 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -72,7 +72,7 @@ func Example_loggingMiddleware() { server.AddReceivingMiddleware(loggingMiddleware) // Add a simple tool - server.AddTool( + mcp.AddTool(server, &mcp.Tool{ Name: "greet", Description: "Greet someone with logging.", @@ -89,19 +89,19 @@ func Example_loggingMiddleware() { }, func( ctx context.Context, - req *mcp.ServerRequest[*mcp.CallToolParamsFor[map[string]any]], - ) (*mcp.CallToolResultFor[any], error) { - name, ok := req.Params.Arguments["name"].(string) + req *mcp.ServerRequest[*mcp.CallToolParams], args map[string]any, + ) (*mcp.CallToolResult, any, error) { + name, ok := args["name"].(string) if !ok { - return nil, fmt.Errorf("name parameter is required and must be a string") + return nil, nil, fmt.Errorf("name parameter is required and must be a string") } message := fmt.Sprintf("Hello, %s!", name) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: message}, }, - }, nil + }, nil, nil }, ) diff --git a/mcp/features_test.go b/mcp/features_test.go index 1c22ecd3..6df9b16e 100644 --- a/mcp/features_test.go +++ b/mcp/features_test.go @@ -5,7 +5,6 @@ package mcp import ( - "context" "slices" "testing" @@ -18,14 +17,6 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[SayHiParams]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{ - Content: []Content{ - &TextContent{Text: "Hi " + params.Name}, - }, - }, nil -} - func TestFeatureSetOrder(t *testing.T) { toolA := &Tool{Name: "apple", Description: "apple tool"} toolB := &Tool{Name: "banana", Description: "banana tool"} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 159f878f..44dd76d2 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -33,11 +33,11 @@ type hiParams struct { // TODO(jba): after schemas are stateless (WIP), this can be a variable. func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } -func sayHi(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { +func sayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err != nil { - return nil, fmt.Errorf("ping failed: %v", err) + return nil, nil, fmt.Errorf("ping failed: %v", err) } - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil } var codeReviewPrompt = &Prompt{ @@ -97,9 +97,9 @@ func TestEndToEnd(t *testing.T) { Name: "greet", Description: "say hi", }, sayHi) - s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { - return nil, errTestFailure + AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, + func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) { + return nil, nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { @@ -663,18 +663,18 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): cancelled <- struct{}{} case <-time.After(5 * time.Second): - return nil, nil + return nil, nil, nil } - return nil, nil + return nil, nil, nil } cs, _ := basicConnection(t, func(s *Server) { - AddTool(s, &Tool{Name: "slow"}, slowRequest) + AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest) }) defer cs.Close() @@ -852,7 +852,7 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware { } } -func nopHandler(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { +func nopHandler(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) { return nil, nil } @@ -1015,11 +1015,11 @@ func TestSynchronousNotifications(t *testing.T) { } server := NewServer(testImpl, serverOpts) cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { - AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { if !rootsChanged.Load() { - return nil, fmt.Errorf("didn't get root change notification") + return nil, nil, fmt.Errorf("didn't get root change notification") } - return new(CallToolResult), nil + return new(CallToolResult), nil, nil }) }) @@ -1064,13 +1064,13 @@ func TestNoDistributedDeadlock(t *testing.T) { } client := NewClient(testImpl, clientOpts) cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { - AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { req.Session.CreateMessage(ctx, new(CreateMessageParams)) - return new(CallToolResult), nil + return new(CallToolResult), nil, nil }) - AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { req.Session.Ping(ctx, nil) - return new(CallToolResult), nil + return new(CallToolResult), nil, nil }) }) defer cs.Close() diff --git a/mcp/protocol.go b/mcp/protocol.go index 666c1bc7..75db7613 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -40,20 +40,32 @@ type Annotations struct { Priority float64 `json:"priority,omitempty"` } -type CallToolParams = CallToolParamsFor[any] - -type CallToolParamsFor[In any] struct { +type CallToolParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Name string `json:"name"` - Arguments In `json:"arguments,omitempty"` + Arguments any `json:"arguments,omitempty"` } -// The server's response to a tool call. -type CallToolResult = CallToolResultFor[any] +// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments. +func (c *CallToolParams) UnmarshalJSON(data []byte) error { + var raw struct { + Meta `json:"_meta,omitempty"` + Name string `json:"name"` + RawArguments json.RawMessage `json:"arguments,omitempty"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + c.Meta = raw.Meta + c.Name = raw.Name + c.Arguments = raw.RawArguments + return nil +} -type CallToolResultFor[Out any] struct { +// The server's response to a tool call. +type CallToolResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` @@ -62,7 +74,7 @@ type CallToolResultFor[Out any] struct { Content []Content `json:"content"` // An optional JSON object that represents the structured result of the tool // call. - StructuredContent Out `json:"structuredContent,omitempty"` + StructuredContent any `json:"structuredContent,omitempty"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -78,12 +90,12 @@ type CallToolResultFor[Out any] struct { IsError bool `json:"isError,omitempty"` } -func (*CallToolResultFor[Out]) isResult() {} +func (*CallToolResult) isResult() {} // UnmarshalJSON handles the unmarshalling of content into the Content // interface. -func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { - type res CallToolResultFor[Out] // avoid recursion +func (x *CallToolResult) UnmarshalJSON(data []byte) error { + type res CallToolResult // avoid recursion var wire struct { res Content []*wireContent `json:"content"` @@ -95,13 +107,13 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } - *x = CallToolResultFor[Out](wire.res) + *x = CallToolResult(wire.res) return nil } -func (x *CallToolParamsFor[Out]) isParams() {} -func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) } -func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *CallToolParams) isParams() {} +func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } type CancelledParams struct { // This property is reserved by the protocol to allow clients and servers to diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index dba80a8b..28e97518 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -208,6 +208,7 @@ func TestCompleteReference(t *testing.T) { }) } } + func TestCompleteParams(t *testing.T) { // Define test cases specifically for Marshalling marshalTests := []struct { @@ -514,13 +515,13 @@ func TestContentUnmarshal(t *testing.T) { var got CallToolResult roundtrip(ctr, &got) - ctrf := &CallToolResultFor[int]{ + ctrf := &CallToolResult{ Meta: Meta{"m": true}, Content: content, IsError: true, - StructuredContent: 3, + StructuredContent: 3.0, } - var gotf CallToolResultFor[int] + var gotf CallToolResult roundtrip(ctrf, &gotf) pm := &PromptMessage{ diff --git a/mcp/server.go b/mcp/server.go index 71257ccf..b8e72907 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -15,10 +15,12 @@ import ( "maps" "net/url" "path/filepath" + "reflect" "slices" "sync" "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -146,53 +148,128 @@ func (s *Server) RemovePrompts(names ...string) { // The tool's input schema must be non-nil. For a tool that takes no input, // or one where any input is valid, set [Tool.InputSchema] to the empty schema, // &jsonschema.Schema{}. +// +// When the handler is invoked as part of a CallTool request, req.Params.Arguments +// will be a json.RawMessage. Unmarshaling the arguments and validating them against the +// input schema are the handler author's responsibility. +// +// Most users will prefer the top-level function [AddTool]. func (s *Server) AddTool(t *Tool, h ToolHandler) { if t.InputSchema == nil { // This prevents the tool author from forgetting to write a schema where // one should be provided. If we papered over this by supplying the empty // schema, then every input would be validated and the problem wouldn't be // discovered until runtime, when the LLM sent bad data. - panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name)) - } - if err := addToolErr(s, t, h); err != nil { - panic(err) + panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) } + st := &serverTool{tool: t, handler: h} + // Assume there was a change, since add replaces existing tools. + // (It's possible a tool was replaced with an identical one, but not worth checking.) + // TODO: Batch these changes by size and time? The typescript SDK doesn't. + // TODO: Surface notify error here? best not, in case we need to batch. + s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, + func() bool { s.tools.add(st); return true }) } -// AddTool adds a [Tool] to the server, or replaces one with the same name. +// toolFor returns a shallow copy of t and a [ToolHandler] that wraps h. // If the tool's input schema is nil, it is set to the schema inferred from the In // type parameter, using [jsonschema.For]. // If the tool's output schema is nil and the Out type parameter is not the empty // interface, then the output schema is set to the schema inferred from Out. -// The Tool argument must not be modified after this call. -func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - if err := addToolErr(s, t, h); err != nil { - panic(err) +// +// Most users will call [AddTool]. Use [toolFor] if you wish to wrap the ToolHandler +// before calling [Server.AddTool]. +func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler) { + tt, hh, err := toolForErr(t, h) + if err != nil { + panic(fmt.Sprintf("ToolFor: tool %q: %v", t.Name, err)) } + return tt, hh } -func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) { - defer util.Wrapf(&err, "adding tool %q", t.Name) - // If the exact same Tool pointer has already been registered under this name, - // avoid rebuilding schemas and re-registering. This prevents duplicate - // registration from causing errors (and unnecessary work). - s.mu.Lock() - if existing, ok := s.tools.get(t.Name); ok && existing.tool == t { - s.mu.Unlock() - return nil +// TODO(v0.3.0): test +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { + var err error + tt := *t + tt.InputSchema = t.InputSchema + if tt.InputSchema == nil { + tt.InputSchema, err = jsonschema.For[In](nil) + if err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } } - s.mu.Unlock() - st, err := newServerTool(t, h) + inputResolved, err := tt.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) if err != nil { - return err + return nil, nil, fmt.Errorf("resolving input schema: %w", err) } - // Assume there was a change, since add replaces existing tools. - // (It's possible a tool was replaced with an identical one, but not worth checking.) - // TODO: Batch these changes by size and time? The typescript SDK doesn't. - // TODO: Surface notify error here? best not, in case we need to batch. - s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { s.tools.add(st); return true }) - return nil + + if tt.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + tt.OutputSchema, err = jsonschema.For[Out](nil) + } + if err != nil { + return nil, nil, fmt.Errorf("output schema: %w", err) + } + var outputResolved *jsonschema.Resolved + if tt.OutputSchema != nil { + outputResolved, err = tt.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + if err != nil { + return nil, nil, fmt.Errorf("resolving output schema: %w", err) + } + } + + th := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + // Unmarshal and validate args. + rawArgs := req.Params.Arguments.(json.RawMessage) + var in In + if rawArgs != nil { + if err := unmarshalSchema(rawArgs, inputResolved, &in); err != nil { + return nil, err + } + } + + // Call typed handler. + res, out, err := h(ctx, req, in) + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors + if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc2.WireError); ok { + return nil, wireErr + } + // For regular errors, embed them in the tool result as per MCP spec + return &CallToolResult{ + Content: []Content{&TextContent{Text: err.Error()}}, + IsError: true, + }, nil + } + + // TODO(v0.3.0): Validate out. + _ = outputResolved + + // TODO: return the serialized JSON in a TextContent block, as per spec? + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + // But people may use res.Content for other things. + if res == nil { + res = &CallToolResult{} + } + res.StructuredContent = out + return res, nil + } + + return &tt, th, nil +} + +// AddTool adds a tool and handler to the server. +// +// A shallow copy of the tool is made first. +// If the tool's input schema is nil, the copy's input schema is set to the schema +// inferred from the In type parameter, using [jsonschema.For]. +// If the tool's output schema is nil and the Out type parameter is not the empty +// interface, then the copy's output schema is set to the schema inferred from Out. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + s.AddTool(toolFor(t, h)) } // RemoveTools removes the tools with the given names. @@ -352,7 +429,7 @@ func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParam }) } -func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParamsFor[json.RawMessage]]) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { s.mu.Lock() st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index f735b84e..2b4a0bf1 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -16,12 +16,12 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, nil, nil } func ExampleServer() { diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 01d1eff7..4d0859ac 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,232 +4,222 @@ package mcp -import ( - "context" - "encoding/json" - "fmt" - "strings" - "testing" -) - -// TODO(jba): this shouldn't be in this file, but tool_test.go doesn't have access to unexported symbols. -func TestToolValidate(t *testing.T) { - // Check that the tool returned from NewServerTool properly validates its input schema. - - type req struct { - I int - B bool - S string `json:",omitempty"` - P *int `json:",omitempty"` - } - - dummyHandler := func(context.Context, *ServerRequest[*CallToolParamsFor[req]]) (*CallToolResultFor[any], error) { - return nil, nil - } - - st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) - if err != nil { - t.Fatal(err) - } - - for _, tt := range []struct { - desc string - args map[string]any - want string // error should contain this string; empty for success - }{ - { - "both required", - map[string]any{"I": 1, "B": true}, - "", - }, - { - "optional", - map[string]any{"I": 1, "B": true, "S": "foo"}, - "", - }, - { - "wrong type", - map[string]any{"I": 1.5, "B": true}, - "cannot unmarshal", - }, - { - "extra property", - map[string]any{"I": 1, "B": true, "C": 2}, - "unknown field", - }, - { - "value for pointer", - map[string]any{"I": 1, "B": true, "P": 3}, - "", - }, - { - "null for pointer", - map[string]any{"I": 1, "B": true, "P": nil}, - "", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - raw, err := json.Marshal(tt.args) - if err != nil { - t.Fatal(err) - } - _, err = st.handler(context.Background(), &ServerRequest[*CallToolParamsFor[json.RawMessage]]{ - Params: &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}, - }) - if err == nil && tt.want != "" { - t.Error("got success, wanted failure") - } - if err != nil { - if tt.want == "" { - t.Fatalf("failed with:\n%s\nwanted success", err) - } - if !strings.Contains(err.Error(), tt.want) { - t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) - } - } - }) - } -} +// TODO(v0.3.0): rewrite this test. +// func TestToolValidate(t *testing.T) { +// // Check that the tool returned from NewServerTool properly validates its input schema. + +// type req struct { +// I int +// B bool +// S string `json:",omitempty"` +// P *int `json:",omitempty"` +// } + +// dummyHandler := func(context.Context, *ServerRequest[*CallToolParams], req) (*CallToolResultFor[any], error) { +// return nil, nil +// } + +// st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) +// if err != nil { +// t.Fatal(err) +// } + +// for _, tt := range []struct { +// desc string +// args map[string]any +// want string // error should contain this string; empty for success +// }{ +// { +// "both required", +// map[string]any{"I": 1, "B": true}, +// "", +// }, +// { +// "optional", +// map[string]any{"I": 1, "B": true, "S": "foo"}, +// "", +// }, +// { +// "wrong type", +// map[string]any{"I": 1.5, "B": true}, +// "cannot unmarshal", +// }, +// { +// "extra property", +// map[string]any{"I": 1, "B": true, "C": 2}, +// "unknown field", +// }, +// { +// "value for pointer", +// map[string]any{"I": 1, "B": true, "P": 3}, +// "", +// }, +// { +// "null for pointer", +// map[string]any{"I": 1, "B": true, "P": nil}, +// "", +// }, +// } { +// t.Run(tt.desc, func(t *testing.T) { +// raw, err := json.Marshal(tt.args) +// if err != nil { +// t.Fatal(err) +// } +// _, err = st.handler(context.Background(), &ServerRequest[*CallToolParamsFor[json.RawMessage]]{ +// Params: &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}, +// }) +// if err == nil && tt.want != "" { +// t.Error("got success, wanted failure") +// } +// if err != nil { +// if tt.want == "" { +// t.Fatalf("failed with:\n%s\nwanted success", err) +// } +// if !strings.Contains(err.Error(), tt.want) { +// t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) +// } +// } +// }) +// } +// } // TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams. // This addresses a vulnerability where missing or null parameters could crash the server. -func TestNilParamsHandling(t *testing.T) { - // Define test types for clarity - type TestArgs struct { - Name string `json:"name"` - Value int `json:"value"` - } - type TestParams = *CallToolParamsFor[TestArgs] - type TestResult = *CallToolResultFor[string] - - // Simple test handler - testHandler := func(ctx context.Context, req *ServerRequest[TestParams]) (TestResult, error) { - result := "processed: " + req.Params.Arguments.Name - return &CallToolResultFor[string]{StructuredContent: result}, nil - } - - methodInfo := newServerMethodInfo(testHandler, missingParamsOK) - - // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully - mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { - t.Helper() - - defer func() { - if r := recover(); r != nil { - t.Fatalf("unmarshalParams panicked: %v", r) - } - }() - - params, err := methodInfo.unmarshalParams(rawMsg) - if err != nil { - t.Fatalf("unmarshalParams failed: %v", err) - } - - if expectNil { - if params != nil { - t.Fatalf("Expected nil params, got %v", params) - } - return params - } - - if params == nil { - t.Fatal("unmarshalParams returned unexpected nil") - } - - // Verify the result can be used safely - typedParams := params.(TestParams) - _ = typedParams.Name - _ = typedParams.Arguments.Name - _ = typedParams.Arguments.Value - - return params - } - - // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil - t.Run("missing_params", func(t *testing.T) { - mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag - }) - - t.Run("explicit_null", func(t *testing.T) { - mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag - }) - - t.Run("empty_object", func(t *testing.T) { - mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params - }) - - t.Run("valid_params", func(t *testing.T) { - rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) - params := mustNotPanic(t, rawMsg, false) - - // For valid params, also verify the values are parsed correctly - typedParams := params.(TestParams) - if typedParams.Name != "test" { - t.Errorf("Expected name 'test', got %q", typedParams.Name) - } - if typedParams.Arguments.Name != "hello" { - t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) - } - if typedParams.Arguments.Value != 42 { - t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) - } - }) -} +// func TestNilParamsHandling(t *testing.T) { +// // Define test types for clarity +// type TestArgs struct { +// Name string `json:"name"` +// Value int `json:"value"` +// } + +// // Simple test handler +// testHandler := func(ctx context.Context, req *ServerRequest[**GetPromptParams]) (*GetPromptResult, error) { +// result := "processed: " + req.Params.Arguments.Name +// return &CallToolResultFor[string]{StructuredContent: result}, nil +// } + +// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) + +// // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully +// mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { +// t.Helper() + +// defer func() { +// if r := recover(); r != nil { +// t.Fatalf("unmarshalParams panicked: %v", r) +// } +// }() + +// params, err := methodInfo.unmarshalParams(rawMsg) +// if err != nil { +// t.Fatalf("unmarshalParams failed: %v", err) +// } + +// if expectNil { +// if params != nil { +// t.Fatalf("Expected nil params, got %v", params) +// } +// return params +// } + +// if params == nil { +// t.Fatal("unmarshalParams returned unexpected nil") +// } + +// // Verify the result can be used safely +// typedParams := params.(TestParams) +// _ = typedParams.Name +// _ = typedParams.Arguments.Name +// _ = typedParams.Arguments.Value + +// return params +// } + +// // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil +// t.Run("missing_params", func(t *testing.T) { +// mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag +// }) + +// t.Run("explicit_null", func(t *testing.T) { +// mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag +// }) + +// t.Run("empty_object", func(t *testing.T) { +// mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params +// }) + +// t.Run("valid_params", func(t *testing.T) { +// rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) +// params := mustNotPanic(t, rawMsg, false) + +// // For valid params, also verify the values are parsed correctly +// typedParams := params.(TestParams) +// if typedParams.Name != "test" { +// t.Errorf("Expected name 'test', got %q", typedParams.Name) +// } +// if typedParams.Arguments.Name != "hello" { +// t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) +// } +// if typedParams.Arguments.Value != 42 { +// t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) +// } +// }) +// } // TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix -func TestNilParamsEdgeCases(t *testing.T) { - type TestArgs struct { - Name string `json:"name"` - Value int `json:"value"` - } - type TestParams = *CallToolParamsFor[TestArgs] - - testHandler := func(context.Context, *ServerRequest[TestParams]) (*CallToolResultFor[string], error) { - return &CallToolResultFor[string]{StructuredContent: "test"}, nil - } - - methodInfo := newServerMethodInfo(testHandler, missingParamsOK) - - // These should fail normally, not be treated as nil params - invalidCases := []json.RawMessage{ - json.RawMessage(""), // empty string - should error - json.RawMessage("[]"), // array - should error - json.RawMessage(`"null"`), // string "null" - should error - json.RawMessage("0"), // number - should error - json.RawMessage("false"), // boolean - should error - } - - for i, rawMsg := range invalidCases { - t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { - params, err := methodInfo.unmarshalParams(rawMsg) - if err == nil && params == nil { - t.Error("Should not return nil params without error") - } - }) - } - - // Test that methods without missingParamsOK flag properly reject nil params - t.Run("reject_when_params_required", func(t *testing.T) { - methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag - - testCases := []struct { - name string - params json.RawMessage - }{ - {"nil_params", nil}, - {"null_params", json.RawMessage(`null`)}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := methodInfoStrict.unmarshalParams(tc.params) - if err == nil { - t.Error("Expected error for required params, got nil") - } - if !strings.Contains(err.Error(), "missing required \"params\"") { - t.Errorf("Expected 'missing required params' error, got: %v", err) - } - }) - } - }) -} +// func TestNilParamsEdgeCases(t *testing.T) { +// type TestArgs struct { +// Name string `json:"name"` +// Value int `json:"value"` +// } +// type TestParams = *CallToolParamsFor[TestArgs] + +// testHandler := func(context.Context, *ServerRequest[TestParams]) (*CallToolResultFor[string], error) { +// return &CallToolResultFor[string]{StructuredContent: "test"}, nil +// } + +// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) + +// // These should fail normally, not be treated as nil params +// invalidCases := []json.RawMessage{ +// json.RawMessage(""), // empty string - should error +// json.RawMessage("[]"), // array - should error +// json.RawMessage(`"null"`), // string "null" - should error +// json.RawMessage("0"), // number - should error +// json.RawMessage("false"), // boolean - should error +// } + +// for i, rawMsg := range invalidCases { +// t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { +// params, err := methodInfo.unmarshalParams(rawMsg) +// if err == nil && params == nil { +// t.Error("Should not return nil params without error") +// } +// }) +// } + +// // Test that methods without missingParamsOK flag properly reject nil params +// t.Run("reject_when_params_required", func(t *testing.T) { +// methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag + +// testCases := []struct { +// name string +// params json.RawMessage +// }{ +// {"nil_params", nil}, +// {"null_params", json.RawMessage(`null`)}, +// } + +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// _, err := methodInfoStrict.unmarshalParams(tc.params) +// if err == nil { +// t.Error("Expected error for required params, got nil") +// } +// if !strings.Contains(err.Error(), "missing required \"params\"") { +// t.Errorf("Expected 'missing required params' error, got: %v", err) +// } +// }) +// } +// }) +// } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index b5dfdc56..93ccf788 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -18,12 +18,12 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func Add(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args AddParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("%d", req.Params.Arguments.X+req.Params.Arguments.Y)}, + &mcp.TextContent{Text: fmt.Sprintf("%d", args.X+args.Y)}, }, - }, nil + }, nil, nil } func ExampleSSEHandler() { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index b050b35d..28246fd3 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -49,18 +49,18 @@ func TestStreamableTransports(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - hang := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + hang := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): cancelled <- struct{}{} case <-time.After(5 * time.Second): - return nil, nil + return nil, nil, nil } - return nil, nil + return nil, nil, nil } AddTool(server, &Tool{Name: "hang"}, hang) - AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { // Test that we can make sampling requests during tool handling. // // Try this on both the request context and a background context, so @@ -71,13 +71,13 @@ func TestStreamableTransports(t *testing.T) { } { res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{}) if err != nil { - return nil, err + return nil, nil, err } if g, w := res.Model, "aModel"; g != w { - return nil, fmt.Errorf("got %q, want %q", g, w) + return nil, nil, fmt.Errorf("got %q, want %q", g, w) } } - return &CallToolResultFor[any]{}, nil + return &CallToolResult{}, nil, nil }) // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a @@ -219,8 +219,8 @@ func testClientReplay(t *testing.T, test clientReplayTest) { // proxy-killing action. serverReadyToKillProxy := make(chan struct{}) serverClosed := make(chan struct{}) - server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + AddTool(server, &Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { // Send one message to the request context, and another to a background // context (which will end up on the hanging GET). @@ -236,7 +236,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { // the client's connection drops. req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg3"}) req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) - return new(CallToolResult), nil + return new(CallToolResult), nil, nil }) realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) @@ -353,9 +353,9 @@ func TestServerInitiatedSSE(t *testing.T) { t.Fatalf("client.Connect() failed: %v", err) } defer clientSession.Close() - server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { - return &CallToolResult{}, nil + AddTool(server, &Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, + func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) { + return &CallToolResult{}, nil, nil }) receivedNotifications := readNotifications(t, ctx, notifications, 1) wantReceived := []string{"toolListChanged"} @@ -657,12 +657,14 @@ func TestStreamableServerTransport(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) - AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[any]]) (*CallToolResultFor[any], error) { - if test.tool != nil { - test.tool(t, ctx, req.Session) - } - return &CallToolResultFor[any]{}, nil - }) + server.AddTool( + &Tool{Name: "tool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + if test.tool != nil { + test.tool(t, ctx, req.Session) + } + return &CallToolResult{}, nil + }) // Start the streamable handler. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) @@ -1070,12 +1072,12 @@ func TestEventID(t *testing.T) { func TestStreamableStateless(t *testing.T) { // This version of sayHi expects // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) { + sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err == nil { // ping should fail, but not break the connection t.Errorf("ping succeeded unexpectedly") } - return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) @@ -1177,8 +1179,8 @@ func TestTokenInfo(t *testing.T) { ctx := context.Background() // Create a server with a tool that returns TokenInfo. - tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil + tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParams], _ struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) diff --git a/mcp/tool.go b/mcp/tool.go index 893b48ff..f0178c23 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -9,109 +9,22 @@ import ( "context" "encoding/json" "fmt" - "reflect" "github.com/google/jsonschema-go/jsonschema" - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) // A ToolHandler handles a call to tools/call. // [CallToolParams.Arguments] will contain a map[string]any that has been validated // against the input schema. -type ToolHandler = ToolHandlerFor[map[string]any, any] +type ToolHandler func(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) // A ToolHandlerFor handles a call to tools/call with typed arguments and results. -type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) - -// A rawToolHandler is like a ToolHandler, but takes the arguments as as json.RawMessage. -// Second arg is *Request[*ServerSession, *CallToolParamsFor[json.RawMessage]], but that creates -// a cycle. -type rawToolHandler = func(context.Context, any) (*CallToolResult, error) +type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParams], In) (*CallToolResult, Out, error) // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { tool *Tool - handler rawToolHandler - // Resolved tool schemas. Set in newServerTool. - inputResolved, outputResolved *jsonschema.Resolved -} - -// newServerTool creates a serverTool from a tool and a handler. -// If the tool doesn't have an input schema, it is inferred from In. -// If the tool doesn't have an output schema and Out != any, it is inferred from Out. -func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool, error) { - st := &serverTool{tool: t} - - if err := setSchema[In](&t.InputSchema, &st.inputResolved); err != nil { - return nil, err - } - if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { - if err := setSchema[Out](&t.OutputSchema, &st.outputResolved); err != nil { - return nil, err - } - } - - st.handler = func(ctx context.Context, areq any) (*CallToolResult, error) { - req := areq.(*ServerRequest[*CallToolParamsFor[json.RawMessage]]) - var args In - if req.Params.Arguments != nil { - if err := unmarshalSchema(req.Params.Arguments, st.inputResolved, &args); err != nil { - return nil, err - } - } - // TODO(jba): future-proof this copy. - params := &CallToolParamsFor[In]{ - Meta: req.Params.Meta, - Name: req.Params.Name, - Arguments: args, - } - // TODO(jba): improve copy - res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ - Session: req.Session, - Params: params, - Extra: req.Extra, - }) - // Handle server errors appropriately: - // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly - // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true - // - This allows tools to distinguish between protocol errors and tool execution errors - if err != nil { - // Check if this is already a structured JSON-RPC error - if wireErr, ok := err.(*jsonrpc2.WireError); ok { - return nil, wireErr - } - // For regular errors, embed them in the tool result as per MCP spec - return &CallToolResult{ - Content: []Content{&TextContent{Text: err.Error()}}, - IsError: true, - }, nil - } - var ctr CallToolResult - // TODO(jba): What if res == nil? Is that valid? - // TODO(jba): if t.OutputSchema != nil, check that StructuredContent is present and validates. - if res != nil { - // TODO(jba): future-proof this copy. - ctr.Meta = res.Meta - ctr.Content = res.Content - ctr.IsError = res.IsError - ctr.StructuredContent = res.StructuredContent - } - return &ctr, nil - } - - return st, nil -} - -func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error { - var err error - if *sfield == nil { - *sfield, err = jsonschema.For[T](nil) - } - if err != nil { - return err - } - *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - return err + handler ToolHandler } // unmarshalSchema unmarshals data into v and validates the result according to diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 4c73ec63..756d6aa4 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -13,91 +13,10 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) -// testToolHandler is used for type inference in TestNewServerTool. -func testToolHandler[In, Out any](context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) { - panic("not implemented") -} - -func srvTool[In, Out any](t *testing.T, tool *Tool, handler ToolHandlerFor[In, Out]) *serverTool { - t.Helper() - st, err := newServerTool(tool, handler) - if err != nil { - t.Fatal(err) - } - return st -} - -func TestNewServerTool(t *testing.T) { - type ( - Name struct { - Name string `json:"name"` - } - Size struct { - Size int `json:"size"` - } - ) - - nameSchema := &jsonschema.Schema{ - Type: "object", - Required: []string{"name"}, - Properties: map[string]*jsonschema.Schema{ - "name": {Type: "string"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - } - sizeSchema := &jsonschema.Schema{ - Type: "object", - Required: []string{"size"}, - Properties: map[string]*jsonschema.Schema{ - "size": {Type: "integer"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - } - - tests := []struct { - tool *serverTool - wantIn, wantOut *jsonschema.Schema - }{ - { - srvTool(t, &Tool{Name: "basic"}, testToolHandler[Name, Size]), - nameSchema, - sizeSchema, - }, - { - srvTool(t, &Tool{ - Name: "in untouched", - InputSchema: &jsonschema.Schema{}, - }, testToolHandler[Name, Size]), - &jsonschema.Schema{}, - sizeSchema, - }, - { - srvTool(t, &Tool{Name: "out untouched", OutputSchema: &jsonschema.Schema{}}, testToolHandler[Name, Size]), - nameSchema, - &jsonschema.Schema{}, - }, - { - srvTool(t, &Tool{Name: "nil out"}, testToolHandler[Name, any]), - nameSchema, - nil, - }, - } - for _, test := range tests { - if diff := cmp.Diff(test.wantIn, test.tool.tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("newServerTool(%q) input schema mismatch (-want +got):\n%s", test.tool.tool.Name, diff) - } - if diff := cmp.Diff(test.wantOut, test.tool.tool.OutputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("newServerTool(%q) output schema mismatch (-want +got):\n%s", test.tool.tool.Name, diff) - } - } -} - func TestUnmarshalSchema(t *testing.T) { schema := &jsonschema.Schema{ Type: "object", @@ -142,16 +61,16 @@ func TestToolErrorHandling(t *testing.T) { server := NewServer(testImpl, nil) // Create a tool that returns a structured error - structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) { - return nil, &jsonrpc2.WireError{ + structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + return nil, nil, &jsonrpc2.WireError{ Code: CodeInvalidParams, Message: "internal server error", } } // Create a tool that returns a regular error - regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) { - return nil, fmt.Errorf("tool execution failed") + regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + return nil, nil, fmt.Errorf("tool execution failed") } AddTool(server, &Tool{Name: "error_tool", Description: "returns structured error"}, structuredErrorHandler) @@ -201,7 +120,6 @@ func TestToolErrorHandling(t *testing.T) { Name: "regular_error_tool", Arguments: map[string]any{}, }) - // Should not get an error at the protocol level if err != nil { t.Fatalf("unexpected protocol error: %v", err) From 5b1f328fd4232dc8d6335dcb37bc6553291b8b6c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 21 Aug 2025 10:18:52 -0400 Subject: [PATCH 121/221] mcp: uncomment test, change struct{} to any (#341) Address comments on #325. --- examples/server/memory/kb.go | 15 +- examples/server/memory/kb_test.go | 352 +++++++++++++----------------- mcp/streamable_test.go | 2 +- 3 files changed, 163 insertions(+), 206 deletions(-) diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index 2277c22b..e28a4909 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -482,34 +482,34 @@ func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerReque }, nil } -func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteEntitiesArgs) (*mcp.CallToolResult, struct{}, error) { +func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { var res mcp.CallToolResult err := k.deleteEntities(args.EntityNames) if err != nil { - return nil, struct{}{}, err + return nil, nil, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Entities deleted successfully"}, } - return &res, struct{}{}, nil + return &res, nil, nil } -func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteObservationsArgs) (*mcp.CallToolResult, struct{}, error) { +func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { var res mcp.CallToolResult err := k.deleteObservations(args.Deletions) if err != nil { - return nil, struct{}{}, err + return nil, nil, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Observations deleted successfully"}, } - return &res, struct{}{}, nil + return &res, nil, nil } func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { @@ -554,8 +554,7 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[* &mcp.TextContent{Text: "Nodes searched successfully"}, } - res.StructuredContent = graph - return &res, KnowledgeGraph{}, nil + return &res, graph, nil } func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { diff --git a/examples/server/memory/kb_test.go b/examples/server/memory/kb_test.go index d0cf38c0..5d40ae64 100644 --- a/examples/server/memory/kb_test.go +++ b/examples/server/memory/kb_test.go @@ -427,203 +427,161 @@ func TestFileFormatting(t *testing.T) { } // TestMCPServerIntegration tests the knowledge base through MCP server layer. -// func TestMCPServerIntegration(t *testing.T) { -// for name, newStore := range stores() { -// t.Run(name, func(t *testing.T) { -// s := newStore(t) -// kb := knowledgeBase{s: s} - -// // Create mock server session -// ctx := context.Background() -// serverSession := &mcp.ServerSession{} - -// // Test CreateEntities through MCP -// createEntitiesParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ -// Arguments: CreateEntitiesArgs{ -// Entities: []Entity{ -// { -// Name: "TestPerson", -// EntityType: "Person", -// Observations: []string{"Likes testing"}, -// }, -// }, -// }, -// } - -// createResult, err := kb.CreateEntities(ctx, requestFor(serverSession, createEntitiesParams)) -// if err != nil { -// t.Fatalf("MCP CreateEntities failed: %v", err) -// } -// if createResult.IsError { -// t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) -// } -// if len(createResult.StructuredContent.Entities) != 1 { -// t.Errorf("expected 1 entity created, got %d", len(createResult.StructuredContent.Entities)) -// } - -// // Test ReadGraph through MCP -// readParams := &mcp.CallToolParamsFor[struct{}]{} -// readResult, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) -// if err != nil { -// t.Fatalf("MCP ReadGraph failed: %v", err) -// } -// if readResult.IsError { -// t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) -// } -// if len(readResult.StructuredContent.Entities) != 1 { -// t.Errorf("expected 1 entity in graph, got %d", len(readResult.StructuredContent.Entities)) -// } - -// // Test CreateRelations through MCP -// createRelationsParams := &mcp.CallToolParamsFor[CreateRelationsArgs]{ -// Arguments: CreateRelationsArgs{ -// Relations: []Relation{ -// { -// From: "TestPerson", -// To: "Testing", -// RelationType: "likes", -// }, -// }, -// }, -// } - -// relationsResult, err := kb.CreateRelations(ctx, requestFor(serverSession, createRelationsParams)) -// if err != nil { -// t.Fatalf("MCP CreateRelations failed: %v", err) -// } -// if relationsResult.IsError { -// t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) -// } -// if len(relationsResult.StructuredContent.Relations) != 1 { -// t.Errorf("expected 1 relation created, got %d", len(relationsResult.StructuredContent.Relations)) -// } - -// // Test AddObservations through MCP -// addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ -// Arguments: AddObservationsArgs{ -// Observations: []Observation{ -// { -// EntityName: "TestPerson", -// Contents: []string{"Works remotely", "Drinks coffee"}, -// }, -// }, -// }, -// } - -// obsResult, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) -// if err != nil { -// t.Fatalf("MCP AddObservations failed: %v", err) -// } -// if obsResult.IsError { -// t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) -// } -// if len(obsResult.StructuredContent.Observations) != 1 { -// t.Errorf("expected 1 observation result, got %d", len(obsResult.StructuredContent.Observations)) -// } - -// // Test SearchNodes through MCP -// searchParams := &mcp.CallToolParamsFor[SearchNodesArgs]{ -// Arguments: SearchNodesArgs{ -// Query: "coffee", -// }, -// } - -// searchResult, err := kb.SearchNodes(ctx, requestFor(serverSession, searchParams)) -// if err != nil { -// t.Fatalf("MCP SearchNodes failed: %v", err) -// } -// if searchResult.IsError { -// t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) -// } -// if len(searchResult.StructuredContent.Entities) != 1 { -// t.Errorf("expected 1 entity from search, got %d", len(searchResult.StructuredContent.Entities)) -// } - -// // Test OpenNodes through MCP -// openParams := &mcp.CallToolParamsFor[OpenNodesArgs]{ -// Arguments: OpenNodesArgs{ -// Names: []string{"TestPerson"}, -// }, -// } - -// openResult, err := kb.OpenNodes(ctx, requestFor(serverSession, openParams)) -// if err != nil { -// t.Fatalf("MCP OpenNodes failed: %v", err) -// } -// if openResult.IsError { -// t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) -// } -// if len(openResult.StructuredContent.Entities) != 1 { -// t.Errorf("expected 1 entity from open, got %d", len(openResult.StructuredContent.Entities)) -// } - -// // Test DeleteObservations through MCP -// deleteObsParams := &mcp.CallToolParamsFor[DeleteObservationsArgs]{ -// Arguments: DeleteObservationsArgs{ -// Deletions: []Observation{ -// { -// EntityName: "TestPerson", -// Observations: []string{"Works remotely"}, -// }, -// }, -// }, -// } - -// deleteObsResult, err := kb.DeleteObservations(ctx, requestFor(serverSession, deleteObsParams)) -// if err != nil { -// t.Fatalf("MCP DeleteObservations failed: %v", err) -// } -// if deleteObsResult.IsError { -// t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) -// } - -// // Test DeleteRelations through MCP -// deleteRelParams := &mcp.CallToolParamsFor[DeleteRelationsArgs]{ -// Arguments: DeleteRelationsArgs{ -// Relations: []Relation{ -// { -// From: "TestPerson", -// To: "Testing", -// RelationType: "likes", -// }, -// }, -// }, -// } - -// deleteRelResult, err := kb.DeleteRelations(ctx, requestFor(serverSession, deleteRelParams)) -// if err != nil { -// t.Fatalf("MCP DeleteRelations failed: %v", err) -// } -// if deleteRelResult.IsError { -// t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) -// } - -// // Test DeleteEntities through MCP -// deleteEntParams := &mcp.CallToolParamsFor[DeleteEntitiesArgs]{ -// Arguments: DeleteEntitiesArgs{ -// EntityNames: []string{"TestPerson"}, -// }, -// } - -// deleteEntResult, err := kb.DeleteEntities(ctx, requestFor(serverSession, deleteEntParams)) -// if err != nil { -// t.Fatalf("MCP DeleteEntities failed: %v", err) -// } -// if deleteEntResult.IsError { -// t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) -// } - -// // Verify final state -// finalRead, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) -// if err != nil { -// t.Fatalf("Final MCP ReadGraph failed: %v", err) -// } -// if len(finalRead.StructuredContent.Entities) != 0 { -// t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.StructuredContent.Entities)) -// } -// }) -// } -// } +func TestMCPServerIntegration(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Create mock server session + ctx := context.Background() + + createResult, out, err := kb.CreateEntities(ctx, nil, CreateEntitiesArgs{ + Entities: []Entity{ + { + Name: "TestPerson", + EntityType: "Person", + Observations: []string{"Likes testing"}, + }, + }, + }) + if err != nil { + t.Fatalf("MCP CreateEntities failed: %v", err) + } + if createResult.IsError { + t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) + } + if len(out.Entities) != 1 { + t.Errorf("expected 1 entity created, got %d", len(out.Entities)) + } + + // Test ReadGraph through MCP + readResult, outg, err := kb.ReadGraph(ctx, nil, nil) + if err != nil { + t.Fatalf("MCP ReadGraph failed: %v", err) + } + if readResult.IsError { + t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) + } + if len(outg.Entities) != 1 { + t.Errorf("expected 1 entity in graph, got %d", len(outg.Entities)) + } + + relationsResult, outr, err := kb.CreateRelations(ctx, nil, CreateRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }) + if err != nil { + t.Fatalf("MCP CreateRelations failed: %v", err) + } + if relationsResult.IsError { + t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) + } + if len(outr.Relations) != 1 { + t.Errorf("expected 1 relation created, got %d", len(outr.Relations)) + } + + obsResult, outo, err := kb.AddObservations(ctx, nil, AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "TestPerson", + Contents: []string{"Works remotely", "Drinks coffee"}, + }, + }, + }) + if err != nil { + t.Fatalf("MCP AddObservations failed: %v", err) + } + if obsResult.IsError { + t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) + } + if len(outo.Observations) != 1 { + t.Errorf("expected 1 observation result, got %d", len(outo.Observations)) + } + + searchResult, outg, err := kb.SearchNodes(ctx, nil, SearchNodesArgs{ + Query: "coffee", + }) + if err != nil { + t.Fatalf("MCP SearchNodes failed: %v", err) + } + if searchResult.IsError { + t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) + } + if len(outg.Entities) != 1 { + t.Errorf("expected 1 entity from search, got %d", len(outg.Entities)) + } + + openResult, outg, err := kb.OpenNodes(ctx, nil, OpenNodesArgs{ + Names: []string{"TestPerson"}, + }) + if err != nil { + t.Fatalf("MCP OpenNodes failed: %v", err) + } + if openResult.IsError { + t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) + } + if len(outg.Entities) != 1 { + t.Errorf("expected 1 entity from open, got %d", len(outg.Entities)) + } + + deleteObsResult, _, err := kb.DeleteObservations(ctx, nil, DeleteObservationsArgs{ + Deletions: []Observation{ + { + EntityName: "TestPerson", + Observations: []string{"Works remotely"}, + }, + }, + }) + if err != nil { + t.Fatalf("MCP DeleteObservations failed: %v", err) + } + if deleteObsResult.IsError { + t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) + } + + deleteRelResult, _, err := kb.DeleteRelations(ctx, nil, DeleteRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }) + if err != nil { + t.Fatalf("MCP DeleteRelations failed: %v", err) + } + if deleteRelResult.IsError { + t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) + } + + deleteEntResult, _, err := kb.DeleteEntities(ctx, nil, DeleteEntitiesArgs{ + EntityNames: []string{"TestPerson"}, + }) + if err != nil { + t.Fatalf("MCP DeleteEntities failed: %v", err) + } + if deleteEntResult.IsError { + t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) + } + + // Verify final state + _, outg, err = kb.ReadGraph(ctx, nil, nil) + if err != nil { + t.Fatalf("Final MCP ReadGraph failed: %v", err) + } + if len(outg.Entities) != 0 { + t.Errorf("expected empty graph after deletion, got %d entities", len(outg.Entities)) + } + }) + } +} // TestMCPErrorHandling tests error scenarios through MCP layer. func TestMCPErrorHandling(t *testing.T) { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 28246fd3..0e9cf455 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -60,7 +60,7 @@ func TestStreamableTransports(t *testing.T) { return nil, nil, nil } AddTool(server, &Tool{Name: "hang"}, hang) - AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { // Test that we can make sampling requests during tool handling. // // Try this on both the request context and a background context, so From 43ad1eb8b4ebfbde45848afc4dc11ca27e218384 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 21 Aug 2025 10:57:28 -0400 Subject: [PATCH 122/221] mcp: treat pointers equivalently to non-pointers when deriving schema As reported in #199 and #200, the fact that we return a possibly "null" schema for pointer types breaks various clients, which expect schemas to be of type "object". This is an unfortunate footgun. For now, assume that the user wants us to treat pointers equivalently to non-pointers. If we want to change this behavior in the future, we can do so behind an option. + a test Also fix the handling of nil results in the case where the output schema is non-nil: we must provide structured content in this case. (This was causing the test to fail). Fixes #199 Fixes #200 --- mcp/mcp_test.go | 101 ++++++++++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 83 ++++++++++++++++++++++++++++----------- 2 files changed, 161 insertions(+), 23 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 44dd76d2..446c7ba6 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1084,3 +1084,104 @@ func TestNoDistributedDeadlock(t *testing.T) { } var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} + +// This test checks that when we use pointer types for tools, we get the same +// schema as when using the non-pointer types. It is too much of a footgun for +// there to be a difference (see #199 and #200). +// +// If anyone asks, we can add an option that controls how pointers are treated. +func TestPointerArgEquivalence(t *testing.T) { + type input struct { + In string + } + type output struct { + Out string + } + cs, _ := basicConnection(t, func(s *Server) { + // Add two equivalent tools, one of which operates in the 'pointer' realm, + // the other of which does not. + // + // We handle a few different types of results, to assert they behave the + // same in all cases. + AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) { + switch in.In { + case "": + return nil, nil, fmt.Errorf("must provide input") + case "nil": + return nil, nil, nil + case "empty": + return &CallToolResult{}, nil, nil + case "ok": + return &CallToolResult{}, &output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) { + switch in.In { + case "": + return nil, output{}, fmt.Errorf("must provide input") + case "nil": + return nil, output{}, nil + case "empty": + return &CallToolResult{}, output{}, nil + case "ok": + return &CallToolResult{}, output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + }) + defer cs.Close() + + ctx := context.Background() + tools, err := cs.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + if got, want := len(tools.Tools), 2; got != want { + t.Fatalf("got %d tools, want %d", got, want) + } + t0 := tools.Tools[0] + t1 := tools.Tools[1] + + // First, check that the tool schemas don't differ. + if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" { + t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" { + t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + + // Then, check that we handle empty input equivalently. + for _, args := range []any{nil, struct{}{}} { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1); diff != "" { + t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) + } + } + + // Then, check that we handle different types of output equivalently. + for _, in := range []string{"nil", "empty", "ok"} { + t.Run(in, func(t *testing.T) { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1); diff != "" { + t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) + } + }) + } +} diff --git a/mcp/server.go b/mcp/server.go index b8e72907..13ecb079 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -189,31 +189,26 @@ func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandle // TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { - var err error tt := *t - tt.InputSchema = t.InputSchema - if tt.InputSchema == nil { - tt.InputSchema, err = jsonschema.For[In](nil) + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + elemZero any // only non-nil if Out is a pointer type + outputResolved *jsonschema.Resolved + ) + if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + elemZero, err = setSchema[Out](&t.OutputSchema, &outputResolved) if err != nil { - return nil, nil, fmt.Errorf("input schema: %w", err) - } - } - inputResolved, err := tt.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return nil, nil, fmt.Errorf("resolving input schema: %w", err) - } - - if tt.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() { - tt.OutputSchema, err = jsonschema.For[Out](nil) - } - if err != nil { - return nil, nil, fmt.Errorf("output schema: %w", err) - } - var outputResolved *jsonschema.Resolved - if tt.OutputSchema != nil { - outputResolved, err = tt.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return nil, nil, fmt.Errorf("resolving output schema: %w", err) + return nil, nil, fmt.Errorf("output schema: %v", err) } } @@ -255,12 +250,54 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan res = &CallToolResult{} } res.StructuredContent = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the non-zero + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + res.StructuredContent = elemZero + } + } + if tt.OutputSchema != nil && elemZero != nil { + res.StructuredContent = elemZero + } return res, nil } return &tt, th, nil } +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { + rt := reflect.TypeFor[T]() + if *sfield == nil { + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + *sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + } + if err != nil { + return zero, err + } + *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} + // AddTool adds a tool and handler to the server. // // A shallow copy of the tool is made first. From 6378df6f24fd069af5336e1679ce057335b1aa0d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 21 Aug 2025 12:10:53 -0400 Subject: [PATCH 123/221] all: remove ServerRequest[T] for concrete T (#342) Replace all occurrences of ServerRequest[*CallToolParams] and other concrete instantiations with CallToolRequest and the like. Make the XXXRequest types aliases, to preserve the convenience of generics for the internal machinery (see shared.go, for example.) I will expand the aliases in a followup PR. --- README.md | 2 +- examples/server/completion/main.go | 2 +- examples/server/custom-transport/main.go | 2 +- examples/server/hello/main.go | 4 +- examples/server/memory/kb.go | 18 ++++----- examples/server/sequentialthinking/main.go | 8 ++-- .../server/sequentialthinking/main_test.go | 22 ++++------- examples/server/sse/main.go | 2 +- internal/readme/server/server.go | 2 +- mcp/example_middleware_test.go | 2 +- mcp/mcp_test.go | 28 +++++++------- mcp/requests.go | 24 ++++++++++++ mcp/resource.go | 2 +- mcp/server.go | 38 +++++++++---------- mcp/server_example_test.go | 2 +- mcp/server_test.go | 12 +++--- mcp/shared_test.go | 2 +- mcp/sse_example_test.go | 2 +- mcp/streamable_test.go | 14 +++---- mcp/tool.go | 4 +- mcp/tool_test.go | 4 +- 21 files changed, 107 insertions(+), 89 deletions(-) create mode 100644 mcp/requests.go diff --git a/README.md b/README.md index b46724b7..4da0ac61 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, }, nil, nil diff --git a/examples/server/completion/main.go b/examples/server/completion/main.go index f05b2721..b0a991fd 100644 --- a/examples/server/completion/main.go +++ b/examples/server/completion/main.go @@ -16,7 +16,7 @@ import ( // a CompletionHandler to an MCP Server's options. func main() { // Define your custom CompletionHandler logic. - myCompletionHandler := func(_ context.Context, req *mcp.ServerRequest[*mcp.CompleteParams]) (*mcp.CompleteResult, error) { + myCompletionHandler := func(_ context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { // In a real application, you'd implement actual completion logic here. // For this example, we return a fixed set of suggestions. var suggestions []string diff --git a/examples/server/custom-transport/main.go b/examples/server/custom-transport/main.go index 72cfc31d..c367cb62 100644 --- a/examples/server/custom-transport/main.go +++ b/examples/server/custom-transport/main.go @@ -85,7 +85,7 @@ type HiArgs struct { } // SayHi is a tool handler that responds with a greeting. -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, struct{}, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.CallToolResult, struct{}, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index f71b0a78..04c0e0b4 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -22,7 +22,7 @@ type HiArgs struct { Name string `json:"name" jsonschema:"the name to say hi to"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, @@ -69,7 +69,7 @@ var embeddedResources = map[string]string{ "info": "This is the hello example server.", } -func handleEmbeddedResource(_ context.Context, req *mcp.ServerRequest[*mcp.ReadResourceParams]) (*mcp.ReadResourceResult, error) { +func handleEmbeddedResource(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { u, err := url.Parse(req.Params.URI) if err != nil { return nil, err diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index e28a4909..c6a59ec0 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -431,7 +431,7 @@ func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { }, nil } -func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { +func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.CallToolRequest, args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { var res mcp.CallToolResult entities, err := k.createEntities(args.Entities) @@ -450,7 +450,7 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerReques return &res, CreateEntitiesResult{Entities: entities}, nil } -func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { +func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.CallToolRequest, args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { var res mcp.CallToolResult relations, err := k.createRelations(args.Relations) @@ -465,7 +465,7 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerReque return &res, CreateRelationsResult{Relations: relations}, nil } -func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { +func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.CallToolRequest, args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { var res mcp.CallToolResult observations, err := k.addObservations(args.Observations) @@ -482,7 +482,7 @@ func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerReque }, nil } -func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { +func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.CallToolRequest, args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { var res mcp.CallToolResult err := k.deleteEntities(args.EntityNames) @@ -497,7 +497,7 @@ func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerReques return &res, nil, nil } -func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { +func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.CallToolRequest, args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { var res mcp.CallToolResult err := k.deleteObservations(args.Deletions) @@ -512,7 +512,7 @@ func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRe return &res, nil, nil } -func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { +func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.CallToolRequest, args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { var res mcp.CallToolResult err := k.deleteRelations(args.Relations) @@ -527,7 +527,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerReque return &res, struct{}{}, nil } -func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args any) (*mcp.CallToolResult, KnowledgeGraph, error) { +func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.CallToolRequest, args any) (*mcp.CallToolResult, KnowledgeGraph, error) { var res mcp.CallToolResult graph, err := k.loadGraph() @@ -542,7 +542,7 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mc return &res, graph, nil } -func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { +func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.CallToolRequest, args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { var res mcp.CallToolResult graph, err := k.searchNodes(args.Query) @@ -557,7 +557,7 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[* return &res, graph, nil } -func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { +func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.CallToolRequest, args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { var res mcp.CallToolResult graph, err := k.openNodes(args.Names) diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go index 100e1167..e0ae5219 100644 --- a/examples/server/sequentialthinking/main.go +++ b/examples/server/sequentialthinking/main.go @@ -231,7 +231,7 @@ func deepCopyThoughts(thoughts []*Thought) []*Thought { } // StartThinking begins a new sequential thinking session for a complex problem. -func StartThinking(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args StartThinkingArgs) (*mcp.CallToolResult, any, error) { +func StartThinking(ctx context.Context, _ *mcp.CallToolRequest, args StartThinkingArgs) (*mcp.CallToolResult, any, error) { sessionID := args.SessionID if sessionID == "" { sessionID = randText() @@ -264,7 +264,7 @@ func StartThinking(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams } // ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. -func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { +func ContinueThinking(ctx context.Context, req *mcp.CallToolRequest, args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { // Handle revision of existing thought if args.ReviseStep != nil { err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { @@ -391,7 +391,7 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP } // ReviewThinking provides a complete review of the thinking process for a session. -func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { +func ReviewThinking(ctx context.Context, req *mcp.CallToolRequest, args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { // Get a snapshot of the session to avoid race conditions sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) if !exists { @@ -428,7 +428,7 @@ func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolPar } // ThinkingHistory handles resource requests for thinking session data and history. -func ThinkingHistory(ctx context.Context, req *mcp.ServerRequest[*mcp.ReadResourceParams]) (*mcp.ReadResourceResult, error) { +func ThinkingHistory(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { // Extract session ID from URI (e.g., "thinking://session_123") u, err := url.Parse(req.Params.URI) if err != nil { diff --git a/examples/server/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go index 8889db7d..2655114c 100644 --- a/examples/server/sequentialthinking/main_test.go +++ b/examples/server/sequentialthinking/main_test.go @@ -387,11 +387,11 @@ func TestThinkingHistory(t *testing.T) { ctx := context.Background() // Test listing all sessions - listParams := &mcp.ReadResourceParams{ - URI: "thinking://sessions", - } - - result, err := ThinkingHistory(ctx, requestFor(listParams)) + result, err := ThinkingHistory(ctx, &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{ + URI: "thinking://sessions", + }, + }) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -417,11 +417,9 @@ func TestThinkingHistory(t *testing.T) { } // Test getting specific session - sessionParams := &mcp.ReadResourceParams{ - URI: "thinking://session1", - } - - result, err = ThinkingHistory(ctx, requestFor(sessionParams)) + result, err = ThinkingHistory(ctx, &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{URI: "thinking://session1"}, + }) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -491,7 +489,3 @@ func TestInvalidOperations(t *testing.T) { t.Error("Expected error for invalid revision step") } } - -func requestFor[P mcp.Params](p P) *mcp.ServerRequest[P] { - return &mcp.ServerRequest[P]{Params: p} -} diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index c2603b41..27f9caed 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -24,7 +24,7 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 087992e8..aff5fcd0 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -16,7 +16,7 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, }, nil, nil diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 0f6d540e..10dda0fa 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -89,7 +89,7 @@ func Example_loggingMiddleware() { }, func( ctx context.Context, - req *mcp.ServerRequest[*mcp.CallToolParams], args map[string]any, + req *mcp.CallToolRequest, args map[string]any, ) (*mcp.CallToolResult, any, error) { name, ok := args["name"].(string) if !ok { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 446c7ba6..42aa06af 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -33,7 +33,7 @@ type hiParams struct { // TODO(jba): after schemas are stateless (WIP), this can be a variable. func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } -func sayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { +func sayHi(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err != nil { return nil, nil, fmt.Errorf("ping failed: %v", err) } @@ -74,20 +74,20 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *ServerRequest[*InitializedParams]) { + InitializedHandler: func(context.Context, *InitializedRequest) { notificationChans["initialized"] <- 0 }, - RootsListChangedHandler: func(context.Context, *ServerRequest[*RootsListChangedParams]) { + RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) { notificationChans["roots"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ServerRequest[*ProgressNotificationParams]) { + ProgressNotificationHandler: func(context.Context, *ProgressNotificationRequest) { notificationChans["progress_server"] <- 0 }, - SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { + SubscribeHandler: func(context.Context, *SubscribeRequest) error { notificationChans["subscribe"] <- 0 return nil }, - UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { notificationChans["unsubscribe"] <- 0 return nil }, @@ -98,7 +98,7 @@ func TestEndToEnd(t *testing.T) { Description: "say hi", }, sayHi) AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) { + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { return nil, nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) @@ -511,7 +511,7 @@ var embeddedResources = map[string]string{ "info": "This is the MCP test server.", } -func handleEmbeddedResource(_ context.Context, req *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) { +func handleEmbeddedResource(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { u, err := url.Parse(req.Params.URI) if err != nil { return nil, err @@ -663,7 +663,7 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + slowRequest := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -852,7 +852,7 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware { } } -func nopHandler(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) { +func nopHandler(context.Context, *CallToolRequest) (*CallToolResult, error) { return nil, nil } @@ -1009,13 +1009,13 @@ func TestSynchronousNotifications(t *testing.T) { var rootsChanged atomic.Bool serverOpts := &ServerOptions{ - RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) { + RootsListChangedHandler: func(_ context.Context, req *RootsListChangedRequest) { rootsChanged.Store(true) }, } server := NewServer(testImpl, serverOpts) cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { - AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { if !rootsChanged.Load() { return nil, nil, fmt.Errorf("didn't get root change notification") } @@ -1064,11 +1064,11 @@ func TestNoDistributedDeadlock(t *testing.T) { } client := NewClient(testImpl, clientOpts) cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { - AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { req.Session.CreateMessage(ctx, new(CreateMessageParams)) return new(CallToolResult), nil, nil }) - AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { req.Session.Ping(ctx, nil) return new(CallToolResult), nil, nil }) diff --git a/mcp/requests.go b/mcp/requests.go new file mode 100644 index 00000000..ceed5026 --- /dev/null +++ b/mcp/requests.go @@ -0,0 +1,24 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file holds the request types. + +package mcp + +// TODO: expand the aliases +type ( + CallToolRequest = ServerRequest[*CallToolParams] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) diff --git a/mcp/resource.go b/mcp/resource.go index 5445715b..0658c661 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -35,7 +35,7 @@ type serverResourceTemplate struct { // A ResourceHandler is a function that reads a resource. // It will be called when the client calls [ClientSession.ReadResource]. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ResourceHandler func(context.Context, *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) +type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceResult, error) // ResourceNotFoundError returns an error indicating that a resource being read could // not be found. diff --git a/mcp/server.go b/mcp/server.go index 13ecb079..7af83824 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -54,24 +54,24 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *ServerRequest[*InitializedParams]) + InitializedHandler func(context.Context, *InitializedRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. - RootsListChangedHandler func(context.Context, *ServerRequest[*RootsListChangedParams]) + RootsListChangedHandler func(context.Context, *RootsListChangedRequest) // If non-nil, called when "notifications/progress" is received. - ProgressNotificationHandler func(context.Context, *ServerRequest[*ProgressNotificationParams]) + ProgressNotificationHandler func(context.Context, *ProgressNotificationRequest) // If non-nil, called when "completion/complete" is received. - CompletionHandler func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) + CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *ServerRequest[*SubscribeParams]) error + SubscribeHandler func(context.Context, *SubscribeRequest) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *ServerRequest[*UnsubscribeParams]) error + UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error // If true, advertises the prompts capability during initialization, // even if no prompts have been registered. HasPrompts bool @@ -212,7 +212,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan } } - th := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { // Unmarshal and validate args. rawArgs := req.Params.Arguments.(json.RawMessage) var in In @@ -395,7 +395,7 @@ func (s *Server) capabilities() *ServerCapabilities { return caps } -func (s *Server) complete(ctx context.Context, req *ServerRequest[*CompleteParams]) (Result, error) { +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (Result, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } @@ -424,7 +424,7 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] { return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, req *ServerRequest[*ListPromptsParams]) (*ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, req *ListPromptsRequest) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -438,7 +438,7 @@ func (s *Server) listPrompts(_ context.Context, req *ServerRequest[*ListPromptsP }) } -func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptParams]) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { s.mu.Lock() prompt, ok := s.prompts.get(req.Params.Name) s.mu.Unlock() @@ -452,7 +452,7 @@ func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptPar return prompt.handler(ctx, req.Session, req.Params) } -func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParams]) (*ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -466,7 +466,7 @@ func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParam }) } -func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { s.mu.Lock() st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() @@ -481,7 +481,7 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam return st.handler(ctx, req) } -func (s *Server) listResources(_ context.Context, req *ServerRequest[*ListResourcesParams]) (*ListResourcesResult, error) { +func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -495,7 +495,7 @@ func (s *Server) listResources(_ context.Context, req *ServerRequest[*ListResour }) } -func (s *Server) listResourceTemplates(_ context.Context, req *ServerRequest[*ListResourceTemplatesParams]) (*ListResourceTemplatesResult, error) { +func (s *Server) listResourceTemplates(_ context.Context, req *ListResourceTemplatesRequest) (*ListResourceTemplatesResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -510,7 +510,7 @@ func (s *Server) listResourceTemplates(_ context.Context, req *ServerRequest[*Li }) } -func (s *Server) readResource(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) { +func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { uri := req.Params.URI // Look up the resource URI in the lists of resources and resource templates. // This is a security check as well as an information lookup. @@ -575,7 +575,7 @@ func fileResourceHandler(dir string) ResourceHandler { if err != nil { panic(err) } - return func(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (_ *ReadResourceResult, err error) { + return func(ctx context.Context, req *ReadResourceRequest) (_ *ReadResourceResult, err error) { defer util.Wrapf(&err, "reading resource %s", req.Params.URI) // TODO(#25): use a memoizing API here. @@ -610,7 +610,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot return nil } -func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribeParams]) (*emptyResult, error) { +func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } @@ -628,7 +628,7 @@ func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribePar return &emptyResult{}, nil } -func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*UnsubscribeParams]) (*emptyResult, error) { +func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emptyResult, error) { if s.opts.UnsubscribeHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } @@ -762,7 +762,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar return nil, nil } -func (s *Server) callRootsListChangedHandler(ctx context.Context, req *ServerRequest[*RootsListChangedParams]) (Result, error) { +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsListChangedRequest) (Result, error) { if h := s.opts.RootsListChangedHandler; h != nil { h(ctx, req) } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 2b4a0bf1..e68dc308 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -16,7 +16,7 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, diff --git a/mcp/server_test.go b/mcp/server_test.go index 39a4cdb4..1ed4c3cc 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -282,10 +282,10 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { + SubscribeHandler: func(context.Context, *SubscribeRequest) error { return nil }, - UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { return nil }, }, @@ -308,7 +308,7 @@ func TestServerCapabilities(t *testing.T) { name: "With completions", configureServer: func(s *Server) {}, serverOpts: ServerOptions{ - CompletionHandler: func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) { + CompletionHandler: func(context.Context, *CompleteRequest) (*CompleteResult, error) { return nil, nil }, }, @@ -326,13 +326,13 @@ func TestServerCapabilities(t *testing.T) { s.AddTool(tool, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { + SubscribeHandler: func(context.Context, *SubscribeRequest) error { return nil }, - UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { return nil }, - CompletionHandler: func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) { + CompletionHandler: func(context.Context, *CompleteRequest) (*CompleteResult, error) { return nil, nil }, }, diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 4d0859ac..23818f87 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -15,7 +15,7 @@ package mcp // P *int `json:",omitempty"` // } -// dummyHandler := func(context.Context, *ServerRequest[*CallToolParams], req) (*CallToolResultFor[any], error) { +// dummyHandler := func(context.Context, *CallToolRequest, req) (*CallToolResultFor[any], error) { // return nil, nil // } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 93ccf788..7d777114 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -18,7 +18,7 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args AddParams) (*mcp.CallToolResult, any, error) { +func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: fmt.Sprintf("%d", args.X+args.Y)}, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0e9cf455..5cd04eca 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -49,7 +49,7 @@ func TestStreamableTransports(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - hang := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + hang := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -60,7 +60,7 @@ func TestStreamableTransports(t *testing.T) { return nil, nil, nil } AddTool(server, &Tool{Name: "hang"}, hang) - AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { // Test that we can make sampling requests during tool handling. // // Try this on both the request context and a background context, so @@ -220,7 +220,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { serverReadyToKillProxy := make(chan struct{}) serverClosed := make(chan struct{}) AddTool(server, &Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { // Send one message to the request context, and another to a background // context (which will end up on the hanging GET). @@ -354,7 +354,7 @@ func TestServerInitiatedSSE(t *testing.T) { } defer clientSession.Close() AddTool(server, &Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) { + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { return &CallToolResult{}, nil, nil }) receivedNotifications := readNotifications(t, ctx, notifications, 1) @@ -659,7 +659,7 @@ func TestStreamableServerTransport(t *testing.T) { server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) server.AddTool( &Tool{Name: "tool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { if test.tool != nil { test.tool(t, ctx, req.Session) } @@ -1072,7 +1072,7 @@ func TestEventID(t *testing.T) { func TestStreamableStateless(t *testing.T) { // This version of sayHi expects // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { + sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err == nil { // ping should fail, but not break the connection t.Errorf("ping succeeded unexpectedly") @@ -1179,7 +1179,7 @@ func TestTokenInfo(t *testing.T) { ctx := context.Background() // Create a server with a tool that returns TokenInfo. - tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParams], _ struct{}) (*CallToolResult, any, error) { + tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil } server := NewServer(testImpl, nil) diff --git a/mcp/tool.go b/mcp/tool.go index f0178c23..bd10a07c 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -16,10 +16,10 @@ import ( // A ToolHandler handles a call to tools/call. // [CallToolParams.Arguments] will contain a map[string]any that has been validated // against the input schema. -type ToolHandler func(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) +type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) // A ToolHandlerFor handles a call to tools/call with typed arguments and results. -type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParams], In) (*CallToolResult, Out, error) +type ToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 756d6aa4..2722a9ac 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -61,7 +61,7 @@ func TestToolErrorHandling(t *testing.T) { server := NewServer(testImpl, nil) // Create a tool that returns a structured error - structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + structuredErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { return nil, nil, &jsonrpc2.WireError{ Code: CodeInvalidParams, Message: "internal server error", @@ -69,7 +69,7 @@ func TestToolErrorHandling(t *testing.T) { } // Create a tool that returns a regular error - regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + regularErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { return nil, nil, fmt.Errorf("tool execution failed") } From d8e18b34cf372c8c4cb5a0c7beea6b7b2b16fcac Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 21 Aug 2025 14:43:22 -0400 Subject: [PATCH 124/221] mcp: fix bugs handling pointer types (#344) Some bad merging in #338 led to a bad bug in the handling of structured output for pointer output types. Add a conformance test and fix the bug. --- mcp/conformance_test.go | 15 +++++++++++ mcp/server.go | 5 +--- mcp/testdata/conformance/server/tools.txtar | 30 +++++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 9bd8b8f6..a8da4fb7 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -8,6 +8,7 @@ package mcp import ( "bytes" + "context" "encoding/json" "errors" "flag" @@ -96,6 +97,18 @@ func TestServerConformance(t *testing.T) { } } +type input struct { + In string `jsonschema:"the input"` +} + +type output struct { + Out string `jsonschema:"the output"` +} + +func structuredTool(ctx context.Context, req *CallToolRequest, args *input) (*CallToolResult, *output, error) { + return nil, &output{"Ack " + args.In}, nil +} + // runServerTest runs the server conformance test. // It must be executed in a synctest bubble. func runServerTest(t *testing.T, test *conformanceTest) { @@ -109,6 +122,8 @@ func runServerTest(t *testing.T, test *conformanceTest) { Name: "greet", Description: "say hi", }, sayHi) + case "structured": + AddTool(s, &Tool{Name: "structured"}, structuredTool) default: t.Fatalf("unknown tool %q", tn) } diff --git a/mcp/server.go b/mcp/server.go index 7af83824..1ffbbc38 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -206,7 +206,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan ) if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { var err error - elemZero, err = setSchema[Out](&t.OutputSchema, &outputResolved) + elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) if err != nil { return nil, nil, fmt.Errorf("output schema: %v", err) } @@ -258,9 +258,6 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan res.StructuredContent = elemZero } } - if tt.OutputSchema != nil && elemZero != nil { - res.StructuredContent = elemZero - } return res, nil } diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index 29dfdc18..870e9ea5 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -8,6 +8,7 @@ Fixed bugs: -- tools -- greet +structured -- client -- { @@ -63,6 +64,35 @@ greet "additionalProperties": false }, "name": "greet" + }, + { + "inputSchema": { + "type": "object", + "required": [ + "In" + ], + "properties": { + "In": { + "type": "string", + "description": "the input" + } + }, + "additionalProperties": false + }, + "name": "structured", + "outputSchema": { + "type": "object", + "required": [ + "Out" + ], + "properties": { + "Out": { + "type": "string", + "description": "the output" + } + }, + "additionalProperties": false + } } ] } From bf79d784b901a64a463c36ebd55b44d2f548085b Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 21 Aug 2025 15:24:58 -0400 Subject: [PATCH 125/221] mcp: remove references to ClientRequest[T] for concrete T (#343) See related PR about ServerRequest[T]. --- mcp/client.go | 32 ++++++++++++++++---------------- mcp/client_test.go | 2 +- mcp/mcp_test.go | 24 ++++++++++++------------ mcp/requests.go | 39 ++++++++++++++++++++++++++------------- mcp/server.go | 4 ++-- mcp/streamable_test.go | 6 +++--- 6 files changed, 60 insertions(+), 47 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 2511c05b..ec1dc456 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -55,14 +55,14 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { type ClientOptions struct { // Handler for sampling. // Called when a server calls CreateMessage. - CreateMessageHandler func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) + CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) // Handlers for notifications from the server. - ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams]) - PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams]) - ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams]) - ResourceUpdatedHandler func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) - LoggingMessageHandler func(context.Context, *ClientRequest[*LoggingMessageParams]) - ProgressNotificationHandler func(context.Context, *ClientRequest[*ProgressNotificationParams]) + ToolListChangedHandler func(context.Context, *ToolListChangedRequest) + PromptListChangedHandler func(context.Context, *PromptListChangedRequest) + ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest) + ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) + LoggingMessageHandler func(context.Context, *LoggingMessageRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. @@ -132,7 +132,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio ClientInfo: c.impl, Capabilities: c.capabilities(), } - req := &ClientRequest[*InitializeParams]{Session: cs, Params: params} + req := &InitializeRequest{Session: cs, Params: params} res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) if err != nil { _ = cs.Close() @@ -145,7 +145,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } - req2 := &ClientRequest[*InitializedParams]{Session: cs, Params: &InitializedParams{}} + req2 := &InitializedClientRequest{Session: cs, Params: &InitializedParams{}} if err := handleNotify(ctx, notificationInitialized, req2); err != nil { _ = cs.Close() return nil, err @@ -248,7 +248,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change notifySessions(sessions, notification, params) } -func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParams]) (*ListRootsResult, error) { +func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) { c.mu.Lock() defer c.mu.Unlock() roots := slices.Collect(c.roots.all()) @@ -260,7 +260,7 @@ func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParam }, nil } -func (c *Client) createMessage(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { +func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage") @@ -436,35 +436,35 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar return err } -func (c *Client) callToolChangedHandler(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) (Result, error) { +func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { if h := c.opts.ToolListChangedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callPromptChangedHandler(ctx context.Context, req *ClientRequest[*PromptListChangedParams]) (Result, error) { +func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) { if h := c.opts.PromptListChangedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callResourceChangedHandler(ctx context.Context, req *ClientRequest[*ResourceListChangedParams]) (Result, error) { +func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) { if h := c.opts.ResourceListChangedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ClientRequest[*ResourceUpdatedNotificationParams]) (Result, error) { +func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) { if h := c.opts.ResourceUpdatedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callLoggingHandler(ctx context.Context, req *ClientRequest[*LoggingMessageParams]) (Result, error) { +func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { h(ctx, req) } diff --git a/mcp/client_test.go b/mcp/client_test.go index 469fa3fb..eaeedc81 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -211,7 +211,7 @@ func TestClientCapabilities(t *testing.T) { name: "With sampling", configureClient: func(s *Client) {}, clientOpts: ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return nil, nil }, }, diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 42aa06af..9c578392 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -74,13 +74,13 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *InitializedRequest) { + InitializedHandler: func(context.Context, *InitializedServerRequest) { notificationChans["initialized"] <- 0 }, RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) { notificationChans["roots"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ProgressNotificationRequest) { + ProgressNotificationHandler: func(context.Context, *ProgressNotificationServerRequest) { notificationChans["progress_server"] <- 0 }, SubscribeHandler: func(context.Context, *SubscribeRequest) error { @@ -129,25 +129,25 @@ func TestEndToEnd(t *testing.T) { loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, - ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) { + ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { notificationChans["tools"] <- 0 }, - PromptListChangedHandler: func(context.Context, *ClientRequest[*PromptListChangedParams]) { + PromptListChangedHandler: func(context.Context, *PromptListChangedRequest) { notificationChans["prompts"] <- 0 }, - ResourceListChangedHandler: func(context.Context, *ClientRequest[*ResourceListChangedParams]) { + ResourceListChangedHandler: func(context.Context, *ResourceListChangedRequest) { notificationChans["resources"] <- 0 }, - LoggingMessageHandler: func(_ context.Context, req *ClientRequest[*LoggingMessageParams]) { + LoggingMessageHandler: func(_ context.Context, req *LoggingMessageRequest) { loggingMessages <- req.Params }, - ProgressNotificationHandler: func(context.Context, *ClientRequest[*ProgressNotificationParams]) { + ProgressNotificationHandler: func(context.Context, *ProgressNotificationClientRequest) { notificationChans["progress_client"] <- 0 }, - ResourceUpdatedHandler: func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) { + ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) { notificationChans["resource_updated"] <- 0 }, } @@ -992,10 +992,10 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { func TestSynchronousNotifications(t *testing.T) { var toolsChanged atomic.Bool clientOpts := &ClientOptions{ - ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) { + ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) { toolsChanged.Store(true) }, - CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if !toolsChanged.Load() { return nil, fmt.Errorf("didn't get a tools changed notification") } @@ -1057,7 +1057,7 @@ func TestNoDistributedDeadlock(t *testing.T) { // possible, and in any case making tool calls asynchronous by default // delegates synchronization to the user. clientOpts := &ClientOptions{ - CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}) return &CreateMessageResult{Content: &TextContent{}}, nil }, diff --git a/mcp/requests.go b/mcp/requests.go index ceed5026..46ff4f8d 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -8,17 +8,30 @@ package mcp // TODO: expand the aliases type ( - CallToolRequest = ServerRequest[*CallToolParams] - CompleteRequest = ServerRequest[*CompleteParams] - GetPromptRequest = ServerRequest[*GetPromptParams] - InitializedRequest = ServerRequest[*InitializedParams] - ListPromptsRequest = ServerRequest[*ListPromptsParams] - ListResourcesRequest = ServerRequest[*ListResourcesParams] - ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] - ListToolsRequest = ServerRequest[*ListToolsParams] - ProgressNotificationRequest = ServerRequest[*ProgressNotificationParams] - ReadResourceRequest = ServerRequest[*ReadResourceParams] - RootsListChangedRequest = ServerRequest[*RootsListChangedParams] - SubscribeRequest = ServerRequest[*SubscribeParams] - UnsubscribeRequest = ServerRequest[*UnsubscribeParams] + CallToolRequest = ServerRequest[*CallToolParams] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedServerRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) + +type ( + CreateMessageRequest = ClientRequest[*CreateMessageParams] + InitializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] ) diff --git a/mcp/server.go b/mcp/server.go index 1ffbbc38..f9ddcd66 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -54,14 +54,14 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *InitializedRequest) + InitializedHandler func(context.Context, *InitializedServerRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. RootsListChangedHandler func(context.Context, *RootsListChangedRequest) // If non-nil, called when "notifications/progress" is received. - ProgressNotificationHandler func(context.Context, *ProgressNotificationRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest) // If non-nil, called when "completion/complete" is received. CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 5cd04eca..603be473 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -121,7 +121,7 @@ func TestStreamableTransports(t *testing.T) { HTTPClient: httpClient, } client := NewClient(testImpl, &ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, }) @@ -255,7 +255,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ - ProgressNotificationHandler: func(ctx context.Context, req *ClientRequest[*ProgressNotificationParams]) { + ProgressNotificationHandler: func(ctx context.Context, req *ProgressNotificationClientRequest) { notifications <- req.Params.Message }, }) @@ -344,7 +344,7 @@ func TestServerInitiatedSSE(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ - ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) { + ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { notifications <- "toolListChanged" }, }) From dd8e3af18d2c45d13f25298fc74fd3e2ca085da5 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 21 Aug 2025 17:41:38 -0400 Subject: [PATCH 126/221] README: soften the warning, and reference roadmap (#345) We can warn a little less loudly against using the SDK. --- README.md | 8 +++----- internal/readme/README.src.md | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 4da0ac61..299245cc 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,9 @@ This repository contains an unreleased implementation of the official Go software development kit (SDK) for the Model Context Protocol (MCP). > [!WARNING] -> The SDK should be considered unreleased, and is currently unstable -> and subject to breaking changes. Please test it out and file bug reports or API -> proposals, but don't use it in real projects. See the issue tracker for known -> issues and missing features. We aim to release a stable version of the SDK in -> August, 2025. +> The SDK is not yet at v1.0.0 and may still be subject to incompatible API +> changes. We aim to tag v1.0.0 in September, 2025. See +> https://github.com/modelcontextprotocol/go-sdk/issues/328 for details. ## Design diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index bf9faa26..de5dd48a 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -14,11 +14,9 @@ This repository contains an unreleased implementation of the official Go software development kit (SDK) for the Model Context Protocol (MCP). > [!WARNING] -> The SDK should be considered unreleased, and is currently unstable -> and subject to breaking changes. Please test it out and file bug reports or API -> proposals, but don't use it in real projects. See the issue tracker for known -> issues and missing features. We aim to release a stable version of the SDK in -> August, 2025. +> The SDK is not yet at v1.0.0 and may still be subject to incompatible API +> changes. We aim to tag v1.0.0 in September, 2025. See +> https://github.com/modelcontextprotocol/go-sdk/issues/328 for details. ## Design From 62db9140828544a8229f858b840f4d3e8d0844db Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 21 Aug 2025 17:54:17 -0400 Subject: [PATCH 127/221] mcp: export ToolFor (#347) This is needed so clients can modify tool schemas, and not merely to wrap ToolHandlers. (A consequence of returning a copy of Tool instead of modifying it.) --- mcp/requests.go | 1 - mcp/server.go | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mcp/requests.go b/mcp/requests.go index 46ff4f8d..52b2039d 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -6,7 +6,6 @@ package mcp -// TODO: expand the aliases type ( CallToolRequest = ServerRequest[*CallToolParams] CompleteRequest = ServerRequest[*CompleteParams] diff --git a/mcp/server.go b/mcp/server.go index f9ddcd66..740b2b9d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -171,15 +171,15 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { func() bool { s.tools.add(st); return true }) } -// toolFor returns a shallow copy of t and a [ToolHandler] that wraps h. +// ToolFor returns a shallow copy of t and a [ToolHandler] that wraps h. // If the tool's input schema is nil, it is set to the schema inferred from the In // type parameter, using [jsonschema.For]. // If the tool's output schema is nil and the Out type parameter is not the empty // interface, then the output schema is set to the schema inferred from Out. // -// Most users will call [AddTool]. Use [toolFor] if you wish to wrap the ToolHandler -// before calling [Server.AddTool]. -func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler) { +// Most users will call [AddTool]. Use [ToolFor] if you wish to modify the tool's +// schemas or wrap the ToolHandler before calling [Server.AddTool]. +func ToolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler) { tt, hh, err := toolForErr(t, h) if err != nil { panic(fmt.Sprintf("ToolFor: tool %q: %v", t.Name, err)) @@ -303,7 +303,7 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) // If the tool's output schema is nil and the Out type parameter is not the empty // interface, then the copy's output schema is set to the schema inferred from Out. func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - s.AddTool(toolFor(t, h)) + s.AddTool(ToolFor(t, h)) } // RemoveTools removes the tools with the given names. From c058c6a93ba6f32d7b0ca6ca5edcc6d384368e1d Mon Sep 17 00:00:00 2001 From: Kartik Verma Date: Fri, 22 Aug 2025 18:59:51 +0530 Subject: [PATCH 128/221] feat/Issue-13: elicitation support (#188) fixes: #13 Implements elicitation functionality from MCP 2025-06-18 specification. Changes - Add ElicitParams and ElicitResult protocol types - Add ServerSession.Elicit() method and ClientOptions.ElicitationHandler - Schema validation enforces top-level properties only - Support for accept/decline/cancel actions with progress tokens Tests - Integration test in TestEndToEnd - Schema validation, error handling, and capability declaration tests --- mcp/client.go | 171 ++++++++++ mcp/elicitation_example_test.go | 86 +++++ mcp/mcp_test.go | 540 ++++++++++++++++++++++++++++++++ mcp/protocol.go | 34 +- mcp/requests.go | 1 + mcp/server.go | 5 + 6 files changed, 836 insertions(+), 1 deletion(-) create mode 100644 mcp/elicitation_example_test.go diff --git a/mcp/client.go b/mcp/client.go index ec1dc456..6df12bb0 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -6,12 +6,14 @@ package mcp import ( "context" + "encoding/json" "fmt" "iter" "slices" "sync" "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -56,6 +58,9 @@ type ClientOptions struct { // Handler for sampling. // Called when a server calls CreateMessage. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // Handler for elicitation. + // Called when a server requests user input via Elicit. + ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) // Handlers for notifications from the server. ToolListChangedHandler func(context.Context, *ToolListChangedRequest) PromptListChangedHandler func(context.Context, *PromptListChangedRequest) @@ -111,6 +116,9 @@ func (c *Client) capabilities() *ClientCapabilities { if c.opts.CreateMessageHandler != nil { caps.Sampling = &SamplingCapabilities{} } + if c.opts.ElicitationHandler != nil { + caps.Elicitation = &ElicitationCapabilities{} + } return caps } @@ -268,6 +276,168 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) ( return c.opts.CreateMessageHandler(ctx, req) } +func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + if c.opts.ElicitationHandler == nil { + // TODO: wrap or annotate this error? Pick a standard code? + return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support elicitation") + } + + // Validate that the requested schema only contains top-level properties without nesting + if err := validateElicitSchema(req.Params.RequestedSchema); err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, err.Error()) + } + + res, err := c.opts.ElicitationHandler(ctx, req) + if err != nil { + return nil, err + } + + // Validate elicitation result content against requested schema + if req.Params.RequestedSchema != nil && res.Content != nil { + resolved, err := req.Params.RequestedSchema.Resolve(nil) + if err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + } + + if err := resolved.Validate(res.Content); err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + } + } + + return res, nil +} + +// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. +// Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only. +func validateElicitSchema(schema *jsonschema.Schema) error { + if schema == nil { + return nil // nil schema is allowed + } + + // The root schema must be of type "object" if specified + if schema.Type != "" && schema.Type != "object" { + return fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) + } + + // Check if the schema has properties + if schema.Properties != nil { + for propName, propSchema := range schema.Properties { + if propSchema == nil { + continue + } + + if err := validateElicitProperty(propName, propSchema); err != nil { + return err + } + } + } + + return nil +} + +// validateElicitProperty validates a single property in an elicitation schema. +func validateElicitProperty(propName string, propSchema *jsonschema.Schema) error { + // Check if this property has nested properties (not allowed) + if len(propSchema.Properties) > 0 { + return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName) + } + + // Validate based on the property type - only primitives are supported + switch propSchema.Type { + case "string": + return validateElicitStringProperty(propName, propSchema) + case "number", "integer": + return validateElicitNumberProperty(propName, propSchema) + case "boolean": + return validateElicitBooleanProperty(propName, propSchema) + default: + return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type) + } +} + +// validateElicitStringProperty validates string-type properties, including enums. +func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema) error { + // Handle enum validation (enums are a special case of strings) + if len(propSchema.Enum) > 0 { + // Enums must be string type (or untyped which defaults to string) + if propSchema.Type != "" && propSchema.Type != "string" { + return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type) + } + // Enum values themselves are validated by the JSON schema library + // Validate enumNames if present - must match enum length + if propSchema.Extra != nil { + if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { + // Type check enumNames - should be a slice + if enumNamesSlice, ok := enumNamesRaw.([]interface{}); ok { + if len(enumNamesSlice) != len(propSchema.Enum) { + return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice)) + } + } else { + return fmt.Errorf("elicit schema property %q has invalid enumNames type, must be an array", propName) + } + } + } + return nil + } + + // Validate format if specified - only specific formats are allowed + if propSchema.Format != "" { + allowedFormats := map[string]bool{ + "email": true, + "uri": true, + "date": true, + "date-time": true, + } + if !allowedFormats[propSchema.Format] { + return fmt.Errorf("elicit schema property %q has unsupported format %q, only email, uri, date, and date-time are allowed", propName, propSchema.Format) + } + } + + // Validate minLength constraint if specified + if propSchema.MinLength != nil { + if *propSchema.MinLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid minLength %d, must be non-negative", propName, *propSchema.MinLength) + } + } + + // Validate maxLength constraint if specified + if propSchema.MaxLength != nil { + if *propSchema.MaxLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid maxLength %d, must be non-negative", propName, *propSchema.MaxLength) + } + // Check that maxLength >= minLength if both are specified + if propSchema.MinLength != nil && *propSchema.MaxLength < *propSchema.MinLength { + return fmt.Errorf("elicit schema property %q has maxLength %d less than minLength %d", propName, *propSchema.MaxLength, *propSchema.MinLength) + } + } + + return nil +} + +// validateElicitNumberProperty validates number and integer-type properties. +func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema) error { + if propSchema.Minimum != nil && propSchema.Maximum != nil { + if *propSchema.Maximum < *propSchema.Minimum { + return fmt.Errorf("elicit schema property %q has maximum %g less than minimum %g", propName, *propSchema.Maximum, *propSchema.Minimum) + } + } + + return nil +} + +// validateElicitBooleanProperty validates boolean-type properties. +func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { + // Validate default value if specified - must be a valid boolean + if propSchema.Default != nil { + var defaultValue bool + if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil { + return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err) + } + } + + return nil +} + // AddSendingMiddleware wraps the current sending method handler using the provided // middleware. Middleware is applied from right to left, so that the first one is // executed first. @@ -308,6 +478,7 @@ var clientMethodInfos = map[string]methodInfo{ methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK), notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), diff --git a/mcp/elicitation_example_test.go b/mcp/elicitation_example_test.go new file mode 100644 index 00000000..526a4881 --- /dev/null +++ b/mcp/elicitation_example_test.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func Example_elicitation() { + ctx := context.Background() + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + // Create server + server := mcp.NewServer(&mcp.Implementation{Name: "config-server", Version: "v1.0.0"}, nil) + + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + log.Fatal(err) + } + + // Create client with elicitation handler + // Note: Never use elicitation for sensitive data like API keys or passwords + client := mcp.NewClient(&mcp.Implementation{Name: "config-client", Version: "v1.0.0"}, &mcp.ClientOptions{ + ElicitationHandler: func(ctx context.Context, request *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + fmt.Printf("Server requests: %s\n", request.Params.Message) + + // In a real application, this would prompt the user for input + // Here we simulate user providing configuration data + return &mcp.ElicitResult{ + Action: "accept", + Content: map[string]any{ + "serverEndpoint": "https://api.example.com", + "maxRetries": float64(3), + "enableLogs": true, + }, + }, nil + }, + }) + + _, err = client.Connect(ctx, clientTransport, nil) + if err != nil { + log.Fatal(err) + } + + // Server requests user configuration via elicitation + configSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "serverEndpoint": {Type: "string", Description: "Server endpoint URL"}, + "maxRetries": {Type: "number", Minimum: ptr(1.0), Maximum: ptr(10.0)}, + "enableLogs": {Type: "boolean", Description: "Enable debug logging"}, + }, + Required: []string{"serverEndpoint"}, + } + + result, err := serverSession.Elicit(ctx, &mcp.ElicitParams{ + Message: "Please provide your configuration settings", + RequestedSchema: configSchema, + }) + if err != nil { + log.Fatal(err) + } + + if result.Action == "accept" { + fmt.Printf("Configuration received: Endpoint: %v, Max Retries: %.0f, Logs: %v\n", + result.Content["serverEndpoint"], + result.Content["maxRetries"], + result.Content["enableLogs"]) + } + + // Output: + // Server requests: Please provide your configuration settings + // Configuration received: Endpoint: https://api.example.com, Max Retries: 3, Logs: true +} + +// ptr is a helper function to create pointers for schema constraints +func ptr[T any](v T) *T { + return &v +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 9c578392..fef0c91e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -132,6 +133,15 @@ func TestEndToEnd(t *testing.T) { CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, + ElicitationHandler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{ + Action: "accept", + Content: map[string]any{ + "name": "Test User", + "email": "test@example.com", + }, + }, nil + }, ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { notificationChans["tools"] <- 0 }, @@ -474,6 +484,19 @@ func TestEndToEnd(t *testing.T) { } }) + t.Run("elicitation", func(t *testing.T) { + result, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Please provide information", + }) + if err != nil { + t.Fatal(err) + } + if result.Action != "accept" { + t.Errorf("got action %q, want %q", result.Action, "accept") + } + + }) + // Disconnect. cs.Close() clientWG.Wait() @@ -906,6 +929,518 @@ func TestKeepAlive(t *testing.T) { } } +func TestElicitationUnsupportedMethod(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Server + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + // Client without ElicitationHandler + c := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test that elicitation fails when no handler is provided + _, err = ss.Elicit(ctx, &ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + + if err == nil { + t.Error("expected error when ElicitationHandler is not provided, got nil") + } + if code := errorCode(err); code != CodeUnsupportedMethod { + t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, CodeUnsupportedMethod) + } + if !strings.Contains(err.Error(), "does not support elicitation") { + t.Errorf("error should mention unsupported elicitation, got: %v", err) + } +} + +func TestElicitationSchemaValidation(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test valid schemas - these should not return errors + validSchemas := []struct { + name string + schema *jsonschema.Schema + }{ + { + name: "nil schema", + schema: nil, + }, + { + name: "empty object schema", + schema: &jsonschema.Schema{ + Type: "object", + }, + }, + { + name: "simple string property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + }, + }, + { + name: "string with valid formats", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "email": {Type: "string", Format: "email"}, + "website": {Type: "string", Format: "uri"}, + "birthday": {Type: "string", Format: "date"}, + "created": {Type: "string", Format: "date-time"}, + }, + }, + }, + { + name: "string with constraints", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(1), MaxLength: ptr(100)}, + }, + }, + }, + { + name: "number with constraints", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "age": {Type: "integer", Minimum: ptr(0.0), Maximum: ptr(150.0)}, + "score": {Type: "number", Minimum: ptr(0.0), Maximum: ptr(100.0)}, + }, + }, + }, + { + name: "boolean with default", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "enabled": {Type: "boolean", Default: json.RawMessage("true")}, + }, + }, + }, + { + name: "string enum", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "status": { + Type: "string", + Enum: []any{ + "active", + "inactive", + "pending", + }, + }, + }, + }, + }, + { + name: "enum with matching enumNames", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + Enum: []any{ + "high", + "medium", + "low", + }, + Extra: map[string]any{ + "enumNames": []interface{}{"High Priority", "Medium Priority", "Low Priority"}, + }, + }, + }, + }, + }, + } + + for _, tc := range validSchemas { + t.Run("valid_"+tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test valid schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if err != nil { + t.Errorf("expected no error for valid schema %q, got: %v", tc.name, err) + } + }) + } + + // Test invalid schemas - these should return errors + invalidSchemas := []struct { + name string + schema *jsonschema.Schema + expectedError string + }{ + { + name: "root schema non-object type", + schema: &jsonschema.Schema{ + Type: "string", + }, + expectedError: "elicit schema must be of type 'object', got \"string\"", + }, + { + name: "nested object property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "user": { + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + }, + }, + }, + expectedError: "elicit schema property \"user\" contains nested properties, only primitive properties are allowed", + }, + { + name: "property with explicit object type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "config": {Type: "object"}, + }, + }, + expectedError: "elicit schema property \"config\" has unsupported type \"object\", only string, number, integer, and boolean are allowed", + }, + { + name: "array property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "tags": {Type: "array", Items: &jsonschema.Schema{Type: "string"}}, + }, + }, + expectedError: "elicit schema property \"tags\" has unsupported type \"array\", only string, number, integer, and boolean are allowed", + }, + { + name: "array without items", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "items": {Type: "array"}, + }, + }, + expectedError: "elicit schema property \"items\" has unsupported type \"array\", only string, number, integer, and boolean are allowed", + }, + { + name: "unsupported string format", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "phone": {Type: "string", Format: "phone"}, + }, + }, + expectedError: "elicit schema property \"phone\" has unsupported format \"phone\", only email, uri, date, and date-time are allowed", + }, + { + name: "unsupported type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "data": {Type: "null"}, + }, + }, + expectedError: "elicit schema property \"data\" has unsupported type \"null\"", + }, + { + name: "string with invalid minLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(-1)}, + }, + }, + expectedError: "elicit schema property \"name\" has invalid minLength -1, must be non-negative", + }, + { + name: "string with invalid maxLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MaxLength: ptr(-5)}, + }, + }, + expectedError: "elicit schema property \"name\" has invalid maxLength -5, must be non-negative", + }, + { + name: "string with maxLength less than minLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(10), MaxLength: ptr(5)}, + }, + }, + expectedError: "elicit schema property \"name\" has maxLength 5 less than minLength 10", + }, + { + name: "number with maximum less than minimum", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "score": {Type: "number", Minimum: ptr(100.0), Maximum: ptr(50.0)}, + }, + }, + expectedError: "elicit schema property \"score\" has maximum 50 less than minimum 100", + }, + { + name: "boolean with invalid default", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "enabled": {Type: "boolean", Default: json.RawMessage(`"not-a-boolean"`)}, + }, + }, + expectedError: "elicit schema property \"enabled\" has invalid default value, must be a boolean", + }, + { + name: "enum with mismatched enumNames length", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + Enum: []any{ + "high", + "medium", + "low", + }, + Extra: map[string]any{ + "enumNames": []interface{}{"High Priority", "Medium Priority"}, // Only 2 names for 3 values + }, + }, + }, + }, + expectedError: "elicit schema property \"priority\" has 3 enum values but 2 enumNames, they must match", + }, + { + name: "enum with invalid enumNames type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "status": { + Type: "string", + Enum: []any{ + "active", + "inactive", + }, + Extra: map[string]any{ + "enumNames": "not an array", // Should be array + }, + }, + }, + }, + expectedError: "elicit schema property \"status\" has invalid enumNames type, must be an array", + }, + { + name: "enum without explicit type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Enum: []any{ + "high", + "medium", + "low", + }, + }, + }, + }, + expectedError: "elicit schema property \"priority\" has unsupported type \"\", only string, number, integer, and boolean are allowed", + }, + { + name: "untyped property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "data": {}, + }, + }, + expectedError: "elicit schema property \"data\" has unsupported type \"\", only string, number, integer, and boolean are allowed", + }, + } + + for _, tc := range invalidSchemas { + t.Run("invalid_"+tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test invalid schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if err == nil { + t.Errorf("expected error for invalid schema %q, got nil", tc.name) + return + } + if code := errorCode(err); code != CodeInvalidParams { + t.Errorf("got error code %d, want %d (CodeInvalidParams)", code, CodeInvalidParams) + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Errorf("error message %q does not contain expected text %q", err.Error(), tc.expectedError) + } + }) + } +} + +func TestElicitationProgressToken(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + params := &ElicitParams{ + Message: "Test progress token", + Meta: Meta{}, + } + params.SetProgressToken("test-token") + + if token := params.GetProgressToken(); token != "test-token" { + t.Errorf("got progress token %v, want %q", token, "test-token") + } + + _, err = ss.Elicit(ctx, params) + if err != nil { + t.Fatal(err) + } +} + +func TestElicitationCapabilityDeclaration(t *testing.T) { + ctx := context.Background() + + t.Run("with_handler", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + // Client with ElicitationHandler should declare capability + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "cancel"}, nil + }, + }) + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // The client should have declared elicitation capability during initialization + // We can verify this worked by successfully making an elicitation call + result, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test capability", + RequestedSchema: &jsonschema.Schema{Type: "object"}, + }) + if err != nil { + t.Errorf("elicitation should work when capability is declared, got error: %v", err) + } + if result.Action != "cancel" { + t.Errorf("got action %q, want %q", result.Action, "cancel") + } + }) + + t.Run("without_handler", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + // Client without ElicitationHandler should not declare capability + c := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Elicitation should fail with UnsupportedMethod + _, err = ss.Elicit(ctx, &ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{Type: "object"}, + }) + + if err == nil { + t.Error("expected UnsupportedMethod error when no capability declared") + } + if code := errorCode(err); code != CodeUnsupportedMethod { + t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, CodeUnsupportedMethod) + } + }) +} + func TestKeepAliveFailure(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -1185,3 +1720,8 @@ func TestPointerArgEquivalence(t *testing.T) { }) } } + +// ptr is a helper function to create pointers for schema constraints +func ptr[T any](v T) *T { + return &v +} diff --git a/mcp/protocol.go b/mcp/protocol.go index 75db7613..382f745f 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -968,7 +968,39 @@ func (*ResourceUpdatedNotificationParams) isParams() {} // TODO(jba): add CompleteRequest and related types. -// TODO(jba): add ElicitRequest and related types. +// A request from the server to elicit additional information from the user via the client. +type ElicitParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The message to present to the user. + Message string `json:"message"` + // A restricted subset of JSON Schema. + // Only top-level properties are allowed, without nesting. + RequestedSchema *jsonschema.Schema `json:"requestedSchema"` +} + +func (x *ElicitParams) isParams() {} + +func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to an elicitation/create request from the server. +type ElicitResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The user action in response to the elicitation. + // - "accept": User submitted the form/confirmed the action + // - "decline": User explicitly declined the action + // - "cancel": User dismissed without making an explicit choice + Action string `json:"action"` + // The submitted form data, only present when action is "accept". + // Contains values matching the requested schema. + Content map[string]any `json:"content,omitempty"` +} + +func (*ElicitResult) isResult() {} // An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. diff --git a/mcp/requests.go b/mcp/requests.go index 52b2039d..5c2c98d0 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -24,6 +24,7 @@ type ( type ( CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] InitializedClientRequest = ClientRequest[*InitializedParams] InitializeRequest = ClientRequest[*InitializeParams] ListRootsRequest = ClientRequest[*ListRootsParams] diff --git a/mcp/server.go b/mcp/server.go index 740b2b9d..571c830f 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -834,6 +834,11 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) } +// Elicit sends an elicitation request to the client asking for user input. +func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) { + return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) +} + // Log sends a log message to the client. // The message is not sent if the client has not called SetLevel, or if its level // is below that of the last SetLevel. From 5fd06ae8b74e4b5e801d3104f9ce2d921c908f22 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 22 Aug 2025 14:20:58 +0000 Subject: [PATCH 129/221] mcp: unexport InitializeClientRequest This request is synthetic, and should not be observed by the user. Unexport it and rename InitializeServerRequest to just InitializeRequest. --- mcp/client.go | 2 +- mcp/mcp_test.go | 2 +- mcp/requests.go | 4 ++-- mcp/server.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 6df12bb0..3b1741b3 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -153,7 +153,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } - req2 := &InitializedClientRequest{Session: cs, Params: &InitializedParams{}} + req2 := &initializedClientRequest{Session: cs, Params: &InitializedParams{}} if err := handleNotify(ctx, notificationInitialized, req2); err != nil { _ = cs.Close() return nil, err diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index fef0c91e..52df1479 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -75,7 +75,7 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *InitializedServerRequest) { + InitializedHandler: func(context.Context, *InitializedRequest) { notificationChans["initialized"] <- 0 }, RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) { diff --git a/mcp/requests.go b/mcp/requests.go index 5c2c98d0..3afaac5e 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -10,7 +10,7 @@ type ( CallToolRequest = ServerRequest[*CallToolParams] CompleteRequest = ServerRequest[*CompleteParams] GetPromptRequest = ServerRequest[*GetPromptParams] - InitializedServerRequest = ServerRequest[*InitializedParams] + InitializedRequest = ServerRequest[*InitializedParams] ListPromptsRequest = ServerRequest[*ListPromptsParams] ListResourcesRequest = ServerRequest[*ListResourcesParams] ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] @@ -25,7 +25,7 @@ type ( type ( CreateMessageRequest = ClientRequest[*CreateMessageParams] ElicitRequest = ClientRequest[*ElicitParams] - InitializedClientRequest = ClientRequest[*InitializedParams] + initializedClientRequest = ClientRequest[*InitializedParams] InitializeRequest = ClientRequest[*InitializeParams] ListRootsRequest = ClientRequest[*ListRootsParams] LoggingMessageRequest = ClientRequest[*LoggingMessageParams] diff --git a/mcp/server.go b/mcp/server.go index 571c830f..2206803a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -54,7 +54,7 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *InitializedServerRequest) + InitializedHandler func(context.Context, *InitializedRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). PageSize int From e3e8aaf024fda416ddef97e5f7ff8b3610be952a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 22 Aug 2025 10:44:05 -0400 Subject: [PATCH 130/221] mcp: align PromptHandler args with other handlers (#348) The second arg is a GetPromptRequest. Fixes #300. --- examples/server/hello/main.go | 4 ++-- mcp/client_list_test.go | 2 +- mcp/mcp_test.go | 6 +++--- mcp/prompt.go | 2 +- mcp/server.go | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index 04c0e0b4..72d98b21 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -30,11 +30,11 @@ func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.Cal }, nil, nil } -func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { +func PromptHi(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Description: "Code review prompt", Messages: []*mcp.PromptMessage{ - {Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + params.Arguments["name"]}}, + {Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}}, }, }, nil } diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index c1052c25..1449076e 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -126,6 +126,6 @@ func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { } } -func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { +func testPromptHandler(context.Context, *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { panic("not implemented") } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 52df1479..0de0325a 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -47,11 +47,11 @@ var codeReviewPrompt = &Prompt{ Arguments: []*PromptArgument{{Name: "Code", Required: true}}, } -func codReviewPromptHandler(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { +func codReviewPromptHandler(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { return &GetPromptResult{ Description: "Code review prompt", Messages: []*PromptMessage{ - {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, + {Role: "user", Content: &TextContent{Text: "Please review the following code: " + req.Params.Arguments["Code"]}}, }, }, nil } @@ -103,7 +103,7 @@ func TestEndToEnd(t *testing.T) { return nil, nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) - s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { + s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *GetPromptRequest) (*GetPromptResult, error) { return nil, errTestFailure }) s.AddResource(resource1, readHandler) diff --git a/mcp/prompt.go b/mcp/prompt.go index 0ecf5528..62f38a36 100644 --- a/mcp/prompt.go +++ b/mcp/prompt.go @@ -9,7 +9,7 @@ import ( ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error) +type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error) type serverPrompt struct { prompt *Prompt diff --git a/mcp/server.go b/mcp/server.go index 2206803a..c44dfeb6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -446,7 +446,7 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } - return prompt.handler(ctx, req.Session, req.Params) + return prompt.handler(ctx, req) } func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { From 42f419fff41368866fee4c3a16c5d2856095d8d3 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Fri, 22 Aug 2025 11:01:48 -0400 Subject: [PATCH 131/221] mcp/examples: move elicitation example into example folder (#354) Rename and move the example into the examples folder. --- .../server/elicitation/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename mcp/elicitation_example_test.go => examples/server/elicitation/main.go (98%) diff --git a/mcp/elicitation_example_test.go b/examples/server/elicitation/main.go similarity index 98% rename from mcp/elicitation_example_test.go rename to examples/server/elicitation/main.go index 526a4881..59bc25cf 100644 --- a/mcp/elicitation_example_test.go +++ b/examples/server/elicitation/main.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -package mcp_test +package main import ( "context" @@ -13,7 +13,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func Example_elicitation() { +func main() { ctx := context.Background() clientTransport, serverTransport := mcp.NewInMemoryTransports() From 458794199402520b2cf008f3223d30bce204f99a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 22 Aug 2025 11:04:01 -0400 Subject: [PATCH 132/221] mcp: add Request.GetExtra (#350) --- mcp/shared.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mcp/shared.go b/mcp/shared.go index 0675ca45..e3ad6ff7 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -404,6 +404,8 @@ type Request interface { isRequest() GetSession() Session GetParams() Params + // GetExtra returns the Extra field for ServerRequests, and nil for ClientRequests. + GetExtra() *RequestExtra } // A ClientRequest is a request to a client. @@ -435,6 +437,9 @@ func (r *ServerRequest[P]) GetSession() Session { return r.Session } func (r *ClientRequest[P]) GetParams() Params { return r.Params } func (r *ServerRequest[P]) GetParams() Params { return r.Params } +func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } +func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } + func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { return &ServerRequest[P]{Session: s, Params: p} } From 18c96d691275b5aadcc3047604bc41a802b294ae Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 22 Aug 2025 11:04:44 -0400 Subject: [PATCH 133/221] mcp: enforce input schema type "object" (#349) The spec makes it clear that input schemas must have type "object". Enforce that. Allow "any" as an input argument type by special-casing it. Fixes #283. --- mcp/mcp_test.go | 10 +++++----- mcp/server.go | 9 +++++++++ mcp/server_test.go | 2 +- mcp/streamable_test.go | 6 +++--- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 0de0325a..5485eff2 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -98,7 +98,7 @@ func TestEndToEnd(t *testing.T) { Name: "greet", Description: "say hi", }, sayHi) - AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, + AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { return nil, nil, errTestFailure }) @@ -257,7 +257,7 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } - s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{}}, nopHandler) + s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{Type: "object"}}, nopHandler) waitForNotification(t, "tools") s.RemoveTools("T") waitForNotification(t, "tools") @@ -697,7 +697,7 @@ func TestCancellation(t *testing.T) { return nil, nil, nil } cs, _ := basicConnection(t, func(s *Server) { - AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest) + AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowRequest) }) defer cs.Close() @@ -1496,8 +1496,8 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Use two distinct Tool instances with the same name but different // descriptions to ensure the second replaces the first // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors - t1 := &Tool{Name: "dup", Description: "first", InputSchema: &jsonschema.Schema{}} - t2 := &Tool{Name: "dup", Description: "second", InputSchema: &jsonschema.Schema{}} + t1 := &Tool{Name: "dup", Description: "first", InputSchema: &jsonschema.Schema{Type: "object"}} + t2 := &Tool{Name: "dup", Description: "second", InputSchema: &jsonschema.Schema{Type: "object"}} s.AddTool(t1, nopHandler) s.AddTool(t2, nopHandler) }) diff --git a/mcp/server.go b/mcp/server.go index c44dfeb6..caff8dcb 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -162,6 +162,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { // discovered until runtime, when the LLM sent bad data. panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) } + if t.InputSchema.Type != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) + } st := &serverTool{tool: t, handler: h} // Assume there was a change, since add replaces existing tools. // (It's possible a tool was replaced with an identical one, but not worth checking.) @@ -190,6 +193,12 @@ func ToolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandle // TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { tt := *t + + // Special handling for an "any" input: treat as an empty object. + if reflect.TypeFor[In]() == reflect.TypeFor[any]() && t.InputSchema == nil { + tt.InputSchema = &jsonschema.Schema{Type: "object"} + } + var inputResolved *jsonschema.Resolved if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { return nil, nil, fmt.Errorf("input schema: %w", err) diff --git a/mcp/server_test.go b/mcp/server_test.go index 1ed4c3cc..81a59615 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -232,7 +232,7 @@ func TestServerPaginateVariousPageSizes(t *testing.T) { } func TestServerCapabilities(t *testing.T) { - tool := &Tool{Name: "t", InputSchema: &jsonschema.Schema{}} + tool := &Tool{Name: "t", InputSchema: &jsonschema.Schema{Type: "object"}} testCases := []struct { name string configureServer func(s *Server) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 603be473..c9b00ed8 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -219,7 +219,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { // proxy-killing action. serverReadyToKillProxy := make(chan struct{}) serverClosed := make(chan struct{}) - AddTool(server, &Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, + AddTool(server, &Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{Type: "object"}}, func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { // Send one message to the request context, and another to a background // context (which will end up on the hanging GET). @@ -353,7 +353,7 @@ func TestServerInitiatedSSE(t *testing.T) { t.Fatalf("client.Connect() failed: %v", err) } defer clientSession.Close() - AddTool(server, &Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, + AddTool(server, &Tool{Name: "testTool", InputSchema: &jsonschema.Schema{Type: "object"}}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { return &CallToolResult{}, nil, nil }) @@ -658,7 +658,7 @@ func TestStreamableServerTransport(t *testing.T) { // behavior, if any. server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) server.AddTool( - &Tool{Name: "tool", InputSchema: &jsonschema.Schema{}}, + &Tool{Name: "tool", InputSchema: &jsonschema.Schema{Type: "object"}}, func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { if test.tool != nil { test.tool(t, ctx, req.Session) From b891d953610bb410ef81b311f1b7e1472b0b140d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 22 Aug 2025 11:05:03 -0400 Subject: [PATCH 134/221] mcp: example middleware wrapping ToolHandler (#351) Add to the middleware example to show how to wrap a ToolHandler. This middleware is very close to, but not the same as, wrapping a ToolHandler directly. The only difference is that the middleware wraps Server.callTool, which looks up the tool by name in the server's list and then calls the handler. That lookup (plus other intervening middleware, of course) is all that distinguishes this way of wrapping from a more direct wrapping. --- mcp/example_middleware_test.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 10dda0fa..b6b8e52d 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -40,13 +40,16 @@ func Example_loggingMiddleware() { "session_id", req.GetSession().ID(), "has_params", req.GetParams() != nil, ) + // Log more for tool calls. + if ctr, ok := req.(*mcp.CallToolRequest); ok { + logger.Info("Calling tool", + "name", ctr.Params.Name, + "args", ctr.Params.Arguments) + } start := time.Now() - result, err := next(ctx, method, req) - duration := time.Since(start) - if err != nil { logger.Error("MCP method failed", "method", method, @@ -62,7 +65,6 @@ func Example_loggingMiddleware() { "has_result", result != nil, ) } - return result, err } } @@ -134,6 +136,7 @@ func Example_loggingMiddleware() { // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=notifications/initialized session_id="" has_params=true // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=notifications/initialized session_id="" duration_ms=0 has_result=false // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=tools/call session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="Calling tool" name=greet args="{\"name\":\"World\"}" // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=tools/call session_id="" duration_ms=0 has_result=true // Tool result: Hello, World! } From d16ce9c25f5a8df5691049ccb25607303e518b8c Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Fri, 22 Aug 2025 11:18:05 -0400 Subject: [PATCH 135/221] mcp/examples: moves middleware examples into example folder (#356) For better organization. --- .../client/middleware/main.go | 7 +++---- .../server/middleware/main.go | 4 ++-- mcp/mcp-repo-replace.txt | 9 --------- 3 files changed, 5 insertions(+), 15 deletions(-) rename mcp/example_progress_test.go => examples/client/middleware/main.go (89%) rename mcp/example_middleware_test.go => examples/server/middleware/main.go (98%) delete mode 100644 mcp/mcp-repo-replace.txt diff --git a/mcp/example_progress_test.go b/examples/client/middleware/main.go similarity index 89% rename from mcp/example_progress_test.go rename to examples/client/middleware/main.go index 304c838a..6ae87df0 100644 --- a/mcp/example_progress_test.go +++ b/examples/client/middleware/main.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -package mcp_test +package main import ( "context" @@ -15,10 +15,9 @@ var nextProgressToken atomic.Int64 // This middleware function adds a progress token to every outgoing request // from the client. -func Example_progressMiddleware() { - c := mcp.NewClient(testImpl, nil) +func main() { + c := mcp.NewClient(&mcp.Implementation{Name: "test"}, nil) c.AddSendingMiddleware(addProgressToken[*mcp.ClientSession]) - _ = c } func addProgressToken[S mcp.Session](h mcp.MethodHandler) mcp.MethodHandler { diff --git a/mcp/example_middleware_test.go b/examples/server/middleware/main.go similarity index 98% rename from mcp/example_middleware_test.go rename to examples/server/middleware/main.go index b6b8e52d..224c8c6f 100644 --- a/mcp/example_middleware_test.go +++ b/examples/server/middleware/main.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -package mcp_test +package main import ( "context" @@ -16,7 +16,7 @@ import ( ) // This example demonstrates server side logging using the mcp.Middleware system. -func Example_loggingMiddleware() { +func main() { // Create a logger for demonstration purposes. logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelInfo, diff --git a/mcp/mcp-repo-replace.txt b/mcp/mcp-repo-replace.txt deleted file mode 100644 index 3409dd7a..00000000 --- a/mcp/mcp-repo-replace.txt +++ /dev/null @@ -1,9 +0,0 @@ -"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"==>"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" -github.com/modelcontextprotocol/go-sdk/internal/xcontext==>github.com/modelcontextprotocol/go-sdk/internal/xcontext -github.com/modelcontextprotocol/go-sdk/internal==>github.com/modelcontextprotocol/go-sdk/internal -github.com/modelcontextprotocol/go-sdk/jsonschema==>github.com/modelcontextprotocol/go-sdk/jsonschema -github.com/modelcontextprotocol/go-sdk/examples==>github.com/modelcontextprotocol/go-sdk/examples -github.com/modelcontextprotocol/go-sdk/design==>github.com/modelcontextprotocol/go-sdk/design -github.com/modelcontextprotocol/go-sdk/mcp==>github.com/modelcontextprotocol/go-sdk/mcp -governed by an MIT-style==>governed by an MIT-style -regex:Copyright (20\d\d) The Go Authors==>Copyright \1 The Go MCP SDK Authors From c631641296566830c4fdab087fbf6785f7c76bf4 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 22 Aug 2025 11:35:50 -0400 Subject: [PATCH 136/221] mcp: check output schema type (#358) Server.AddTool now checks that an output schema, if any, has type "object". Also add doc and a test. --- mcp/server.go | 27 +++++++++++++++------------ mcp/server_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index caff8dcb..febf2e6e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -145,15 +145,17 @@ func (s *Server) RemovePrompts(names ...string) { // AddTool adds a [Tool] to the server, or replaces one with the same name. // The Tool argument must not be modified after this call. // -// The tool's input schema must be non-nil. For a tool that takes no input, -// or one where any input is valid, set [Tool.InputSchema] to the empty schema, -// &jsonschema.Schema{}. +// The tool's input schema must be non-nil and have the type "object". For a tool +// that takes no input, or one where any input is valid, set [Tool.InputSchema] to +// &jsonschema.Schema{Type: "object"}. +// +// If present, the output schema must also have type "object". // // When the handler is invoked as part of a CallTool request, req.Params.Arguments // will be a json.RawMessage. Unmarshaling the arguments and validating them against the // input schema are the handler author's responsibility. // -// Most users will prefer the top-level function [AddTool]. +// Most users should use the top-level function [AddTool]. func (s *Server) AddTool(t *Tool, h ToolHandler) { if t.InputSchema == nil { // This prevents the tool author from forgetting to write a schema where @@ -165,6 +167,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { if t.InputSchema.Type != "object" { panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) } + if t.OutputSchema != nil && t.OutputSchema.Type != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + } st := &serverTool{tool: t, handler: h} // Assume there was a change, since add replaces existing tools. // (It's possible a tool was replaced with an identical one, but not worth checking.) @@ -176,9 +181,12 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { // ToolFor returns a shallow copy of t and a [ToolHandler] that wraps h. // If the tool's input schema is nil, it is set to the schema inferred from the In -// type parameter, using [jsonschema.For]. +// type parameter, using [jsonschema.For]. The In type parameter must be a map +// or a struct, so that its inferred JSON Schema has type "object". +// // If the tool's output schema is nil and the Out type parameter is not the empty -// interface, then the output schema is set to the schema inferred from Out. +// interface, then the output schema is set to the schema inferred from Out, which +// must be a map or a struct. // // Most users will call [AddTool]. Use [ToolFor] if you wish to modify the tool's // schemas or wrap the ToolHandler before calling [Server.AddTool]. @@ -305,12 +313,7 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) } // AddTool adds a tool and handler to the server. -// -// A shallow copy of the tool is made first. -// If the tool's input schema is nil, the copy's input schema is set to the schema -// inferred from the In type parameter, using [jsonschema.For]. -// If the tool's output schema is nil and the Out type parameter is not the empty -// interface, then the copy's output schema is set to the schema inferred from Out. +// It is a convenience for s.AddTool(ToolFor(t, h)). func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { s.AddTool(ToolFor(t, h)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index 81a59615..cb1c05dc 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -456,3 +456,36 @@ func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized") } } + +// panicks reports whether f() panics. +func panics(f func()) (b bool) { + defer func() { + b = recover() != nil + }() + f() + return false +} + +func TestAddTool(t *testing.T) { + // AddTool should panic if In or Out are not JSON objects. + s := NewServer(testImpl, nil) + if !panics(func() { + AddTool(s, &Tool{Name: "T1"}, func(context.Context, *CallToolRequest, string) (*CallToolResult, any, error) { return nil, nil, nil }) + }) { + t.Error("bad In: expected panic") + } + if panics(func() { + AddTool(s, &Tool{Name: "T2"}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return nil, nil, nil + }) + }) { + t.Error("good In: expected no panic") + } + if !panics(func() { + AddTool(s, &Tool{Name: "T2"}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, int, error) { + return nil, 0, nil + }) + }) { + t.Error("bad Out: expected panic") + } +} From 9754a2aa86207a2f2d2fa66fe1d9989ba47f2ef5 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 22 Aug 2025 17:46:25 +0000 Subject: [PATCH 137/221] examples: add an everything example, and simplify hello Our 'hello' example should be as simple as possible. On the other hand, we should have an 'everything' example that exercises (almost) everything we offer. Using the everything example, I did some final testing using the inspector. This turned up a couple rough edges related to JSON null that I addressed by preferring empty slices. For #33 --- examples/server/everything/main.go | 200 +++++++++++++++++++++++++++++ examples/server/hello/main.go | 98 ++++---------- mcp/server.go | 13 ++ 3 files changed, 240 insertions(+), 71 deletions(-) create mode 100644 examples/server/everything/main.go diff --git a/examples/server/everything/main.go b/examples/server/everything/main.go new file mode 100644 index 00000000..d2b7b337 --- /dev/null +++ b/examples/server/everything/main.go @@ -0,0 +1,200 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The everything server implements all supported features of an MCP server. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "net/url" + "os" + "strings" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + +func main() { + flag.Parse() + + opts := &mcp.ServerOptions{ + Instructions: "Use this server!", + CompletionHandler: complete, // support completions by setting this handler + } + + server := mcp.NewServer(&mcp.Implementation{Name: "everything"}, opts) + + // Add tools that exercise different features of the protocol. + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, contentTool) + mcp.AddTool(server, &mcp.Tool{Name: "greet (structured)"}, structuredTool) // returns structured output + mcp.AddTool(server, &mcp.Tool{Name: "ping"}, pingingTool) // performs a ping + mcp.AddTool(server, &mcp.Tool{Name: "log"}, loggingTool) // performs a log + mcp.AddTool(server, &mcp.Tool{Name: "sample"}, samplingTool) // performs sampling + mcp.AddTool(server, &mcp.Tool{Name: "elicit"}, elicitingTool) // performs elicitation + mcp.AddTool(server, &mcp.Tool{Name: "roots"}, rootsTool) // lists roots + + // Add a basic prompt. + server.AddPrompt(&mcp.Prompt{Name: "greet"}, prompt) + + // Add an embedded resource. + server.AddResource(&mcp.Resource{ + Name: "info", + MIMEType: "text/plain", + URI: "embedded:info", + }, embeddedResource) + + // Serve over stdio, or streamable HTTP if -http is set. + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("MCP handler listening at %s", *httpAddr) + http.ListenAndServe(*httpAddr, handler) + } else { + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} + +func prompt(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Hi prompt", + Messages: []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}, + }, + }, + }, nil +} + +var embeddedResources = map[string]string{ + "info": "This is the hello example server.", +} + +func embeddedResource(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + u, err := url.Parse(req.Params.URI) + if err != nil { + return nil, err + } + if u.Scheme != "embedded" { + return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) + } + key := u.Opaque + text, ok := embeddedResources[key] + if !ok { + return nil, fmt.Errorf("no embedded resource named %q", key) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + {URI: req.Params.URI, MIMEType: "text/plain", Text: text}, + }, + }, nil +} + +type args struct { + Name string `json:"name" jsonschema:"the name to say hi to"` +} + +// contentTool is a tool that returns unstructured content. +// +// Since its output type is 'any', no output schema is created. +func contentTool(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil +} + +type result struct { + Message string `json:"message" jsonschema:"the message to convey"` +} + +// structuredTool returns a structured result. +func structuredTool(ctx context.Context, req *mcp.CallToolRequest, args *args) (*mcp.CallToolResult, *result, error) { + return nil, &result{Message: "Hi " + args.Name}, nil +} + +func pingingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if err := req.Session.Ping(ctx, nil); err != nil { + return nil, nil, fmt.Errorf("ping failed") + } + return nil, nil, nil +} + +func loggingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if err := req.Session.Log(ctx, &mcp.LoggingMessageParams{ + Data: "something happened!", + Level: "error", + }); err != nil { + return nil, nil, fmt.Errorf("log failed") + } + return nil, nil, nil +} + +func rootsTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + res, err := req.Session.ListRoots(ctx, nil) + if err != nil { + return nil, nil, fmt.Errorf("listing roots failed: %v", err) + } + var allroots []string + for _, r := range res.Roots { + allroots = append(allroots, fmt.Sprintf("%s:%s", r.Name, r.URI)) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Join(allroots, ",")}, + }, + }, nil, nil +} + +func samplingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + res, err := req.Session.CreateMessage(ctx, new(mcp.CreateMessageParams)) + if err != nil { + return nil, nil, fmt.Errorf("sampling failed: %v", err) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + res.Content, + }, + }, nil, nil +} + +func elicitingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + res, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ + Message: "provide a random string", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "random": {Type: "string"}, + }, + }, + }) + if err != nil { + return nil, nil, fmt.Errorf("eliciting failed: %v", err) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: res.Content["random"].(string)}, + }, + }, nil, nil +} + +func complete(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return &mcp.CompleteResult{ + Completion: mcp.CompletionResultDetails{ + Total: 1, + Values: []string{req.Params.Argument.Value + "x"}, + }, + }, nil +} diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index 72d98b21..796feff8 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -2,89 +2,45 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// The hello server contains a single tool that says hi to the user. +// +// It runs over the stdio transport. package main import ( "context" - "flag" - "fmt" "log" - "net/http" - "net/url" - "os" "github.com/modelcontextprotocol/go-sdk/mcp" ) -var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") - -type HiArgs struct { - Name string `json:"name" jsonschema:"the name to say hi to"` -} - -func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.CallToolResult, any, error) { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + args.Name}, - }, - }, nil, nil -} - -func PromptHi(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { - return &mcp.GetPromptResult{ - Description: "Code review prompt", - Messages: []*mcp.PromptMessage{ - {Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}}, - }, - }, nil -} - func main() { - flag.Parse() - + // Create a server with a single tool that says "Hi". server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) - mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - server.AddPrompt(&mcp.Prompt{Name: "greet"}, PromptHi) - server.AddResource(&mcp.Resource{ - Name: "info", - MIMEType: "text/plain", - URI: "embedded:info", - }, handleEmbeddedResource) - - if *httpAddr != "" { - handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { - return server - }, nil) - log.Printf("MCP handler listening at %s", *httpAddr) - http.ListenAndServe(*httpAddr, handler) - } else { - t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} - if err := server.Run(context.Background(), t); err != nil { - log.Printf("Server failed: %v", err) - } - } -} -var embeddedResources = map[string]string{ - "info": "This is the hello example server.", -} - -func handleEmbeddedResource(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { - u, err := url.Parse(req.Params.URI) - if err != nil { - return nil, err - } - if u.Scheme != "embedded" { - return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) + // Using the generic AddTool automatically populates the the input and output + // schema of the tool. + // + // The schema considers 'json' and 'jsonschema' struct tags to get argument + // names and descriptions. + type args struct { + Name string `json:"name" jsonschema:"the person to greet"` } - key := u.Opaque - text, ok := embeddedResources[key] - if !ok { - return nil, fmt.Errorf("no embedded resource named %q", key) + mcp.AddTool(server, &mcp.Tool{ + Name: "greet", + Description: "say hi", + }, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil + }) + + // server.Run runs the server on the given transport. + // + // In this case, the server communicates over stdin/stdout. + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Printf("Server failed: %v", err) } - return &mcp.ReadResourceResult{ - Contents: []*mcp.ResourceContents{ - {URI: req.Params.URI, MIMEType: "text/plain", Text: text}, - }, - }, nil } diff --git a/mcp/server.go b/mcp/server.go index febf2e6e..eef6a5ad 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -57,6 +57,8 @@ type ServerOptions struct { InitializedHandler func(context.Context, *InitializedRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). + // + // If zero, defaults to [DefaultPageSize]. PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. RootsListChangedHandler func(context.Context, *RootsListChangedRequest) @@ -266,6 +268,9 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan if res == nil { res = &CallToolResult{} } + if res.Content == nil { + res.Content = []Content{} // avoid returning 'null' + } res.StructuredContent = out if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. @@ -843,6 +848,14 @@ func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) // CreateMessage sends a sampling request to the client. func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + if params == nil { + params = &CreateMessageParams{Messages: []*SamplingMessage{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessage{} // avoid JSON "null" + params = &p2 + } return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) } From fb69971650c79df831a27703a1c05f4a0743de9a Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang <58842577+h9jiang@users.noreply.github.com> Date: Fri, 22 Aug 2025 15:26:52 -0400 Subject: [PATCH 138/221] mcp: cleanup transport after keepalive ping fails An onClose function is passed to the ServerSession and ClientSession to help cleanup resources from the caller. The onClose function will be executed as part of the ServerSession and ClientSession closure. Fixes #258 --- mcp/client.go | 16 +++++++--- mcp/server.go | 20 +++++++++--- mcp/streamable.go | 15 +++++++++ mcp/streamable_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++ mcp/transport.go | 7 +++-- 5 files changed, 118 insertions(+), 11 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 3b1741b3..1ed3b048 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -76,9 +76,9 @@ type ClientOptions struct { // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. -func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState) *ClientSession { +func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { assert(mcpConn != nil && conn != nil, "nil connection") - cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c} + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose} if state != nil { cs.state = *state } @@ -130,7 +130,7 @@ func (c *Client) capabilities() *ClientCapabilities { // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) { - cs, err = connect(ctx, t, c, (*clientSessionState)(nil)) + cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) if err != nil { return nil, err } @@ -173,6 +173,8 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio // Call [ClientSession.Close] to close the connection, or await server // termination with [ClientSession.Wait]. type ClientSession struct { + onClose func() + conn *jsonrpc2.Connection client *Client keepaliveCancel context.CancelFunc @@ -208,7 +210,13 @@ func (cs *ClientSession) Close() error { if cs.keepaliveCancel != nil { cs.keepaliveCancel() } - return cs.conn.Close() + err := cs.conn.Close() + + if cs.onClose != nil { + cs.onClose() + } + + return err } // Wait waits for the connection to be closed by the server. diff --git a/mcp/server.go b/mcp/server.go index eef6a5ad..75632f7b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -698,9 +698,9 @@ func (s *Server) Run(ctx context.Context, t Transport) error { // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState) *ServerSession { +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession { assert(mcpConn != nil && conn != nil, "nil connection") - ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s} + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose} if state != nil { ss.state = *state } @@ -727,6 +727,8 @@ func (s *Server) disconnect(cc *ServerSession) { // ServerSessionOptions configures the server session. type ServerSessionOptions struct { State *ServerSessionState + + onClose func() } // Connect connects the MCP server over the given transport and starts handling @@ -739,10 +741,12 @@ type ServerSessionOptions struct { // If opts.State is non-nil, it is the initial state for the server. func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { var state *ServerSessionState + var onClose func() if opts != nil { state = opts.State + onClose = opts.onClose } - return connect(ctx, t, s, state) + return connect(ctx, t, s, state, onClose) } // TODO: (nit) move all ServerSession methods below the ServerSession declaration. @@ -809,6 +813,8 @@ func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { + onClose func() + server *Server conn *jsonrpc2.Connection mcpConn Connection @@ -1043,7 +1049,13 @@ func (ss *ServerSession) Close() error { // Close is idempotent and conn.Close() handles concurrent calls correctly ss.keepaliveCancel() } - return ss.conn.Close() + err := ss.conn.Close() + + if ss.onClose != nil { + ss.onClose() + } + + return err } // Wait waits for the connection to be closed by the client. diff --git a/mcp/streamable.go b/mcp/streamable.go index 99fbe422..f56b7084 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -40,6 +40,8 @@ type StreamableHTTPHandler struct { getServer func(*http.Request) *Server opts StreamableHTTPOptions + onTransportDeletion func(sessionID string) // for testing only + mu sync.Mutex // TODO: we should store the ServerSession along with the transport, because // we need to cancel keepalive requests when closing the transport. @@ -283,6 +285,19 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque connectOpts = &ServerSessionOptions{ State: state, } + } else { + // Cleanup is only required in stateful mode, as transportation is + // not stored in the map otherwise. + connectOpts = &ServerSessionOptions{ + onClose: func() { + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + if h.onTransportDeletion != nil { + h.onTransportDeletion(transport.SessionID) + } + }, + } } // Pass req.Context() here, to allow middleware to add context values. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index c9b00ed8..52f61720 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "maps" @@ -332,6 +333,76 @@ func testClientReplay(t *testing.T, test clientReplayTest) { } } +func TestServerTransportCleanup(t *testing.T) { + server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond}) + + nClient := 3 + + var mu sync.Mutex + var id int = -1 // session id starting from "0", "1", "2"... + chans := make(map[string]chan struct{}, nClient) + + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + GetSessionID: func() string { + mu.Lock() + defer mu.Unlock() + id++ + if id == nClient { + t.Errorf("creating more than %v session", nClient) + } + chans[fmt.Sprint(id)] = make(chan struct{}, 1) + return fmt.Sprint(id) + }, + }) + + handler.onTransportDeletion = func(sessionID string) { + chans[sessionID] <- struct{}{} + } + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Spin up clients connect to the same server but refuse to ping request. + for range nClient { + client := NewClient(testImpl, nil) + pingMiddleware := func(next MethodHandler) MethodHandler { + return func( + ctx context.Context, + method string, + req Request, + ) (Result, error) { + if method == "ping" { + return &emptyResult{}, errors.New("ping error") + } + return next(ctx, method, req) + } + } + client.AddReceivingMiddleware(pingMiddleware) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + } + + for _, ch := range chans { + select { + case <-ctx.Done(): + t.Errorf("did not capture transport deletion event from all session in 10 seconds") + case <-ch: // Received transport deletion signal of this session + } + } + + handler.mu.Lock() + if len(handler.transports) != 0 { + t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports)) + } + handler.mu.Unlock() +} + // TestServerInitiatedSSE verifies that the persistent SSE connection remains // open and can receive server-initiated events. func TestServerInitiatedSSE(t *testing.T) { diff --git a/mcp/transport.go b/mcp/transport.go index 2bcd8d7d..fac640a6 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -122,7 +122,8 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { } type binder[T handler, State any] interface { - bind(Connection, *jsonrpc2.Connection, State) T + // TODO(rfindley): the bind API has gotten too complicated. Simplify. + bind(Connection, *jsonrpc2.Connection, State, func()) T disconnect(T) } @@ -130,7 +131,7 @@ type handler interface { handle(ctx context.Context, req *jsonrpc.Request) (any, error) } -func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State) (H, error) { +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { var zero H mcpConn, err := t.Connect(ctx) if err != nil { @@ -143,7 +144,7 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, preempter canceller ) bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { - h = b.bind(mcpConn, conn, s) + h = b.bind(mcpConn, conn, s, onClose) preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } From f37e549e2584381b023db70df02cdfe5a586c16d Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 22 Aug 2025 19:52:37 +0000 Subject: [PATCH 139/221] mcp: polish package doc for v0.3.0 Do a pass through the package doc for v0.3.0. --- README.md | 4 +- internal/readme/README.src.md | 4 +- mcp/mcp.go | 98 +++++++++++++++++++++++++++-------- mcp/server.go | 22 +++++--- 4 files changed, 95 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 299245cc..bffd48aa 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# MCP Go SDK v0.2.0 +# MCP Go SDK v0.3.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) @@ -7,7 +7,7 @@ This version contains breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.2.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.3.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index de5dd48a..bafcbc73 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,4 +1,4 @@ -# MCP Go SDK v0.2.0 +# MCP Go SDK v0.3.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) @@ -6,7 +6,7 @@ This version contains breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.2.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.3.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) diff --git a/mcp/mcp.go b/mcp/mcp.go index a22748c9..839f9199 100644 --- a/mcp/mcp.go +++ b/mcp/mcp.go @@ -7,32 +7,84 @@ // The mcp package provides an SDK for writing model context protocol clients // and servers. // -// To get started, create either a [Client] or [Server], and connect it to a -// peer using a [Transport]. The diagram below illustrates how this works: +// To get started, create either a [Client] or [Server], add features to it +// using `AddXXX` functions, and connect it to a peer using a [Transport]. +// +// For example, to run a simple server on the [StdioTransport]: +// +// server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) +// +// // Using the generic AddTool automatically populates the the input and output +// // schema of the tool. +// type args struct { +// Name string `json:"name" jsonschema:"the person to greet"` +// } +// mcp.AddTool(server, &mcp.Tool{ +// Name: "greet", +// Description: "say hi", +// }, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { +// return &mcp.CallToolResult{ +// Content: []mcp.Content{ +// &mcp.TextContent{Text: "Hi " + args.Name}, +// }, +// }, nil, nil +// }) +// +// // Run the server on the stdio transport. +// if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { +// log.Printf("Server failed: %v", err) +// } +// +// To connect to this server, use the [CommandTransport]: +// +// client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) +// transport := &mcp.CommandTransport{Command: exec.Command("myserver")} +// session, err := client.Connect(ctx, transport, nil) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Close() +// +// params := &mcp.CallToolParams{ +// Name: "greet", +// Arguments: map[string]any{"name": "you"}, +// } +// res, err := session.CallTool(ctx, params) +// if err != nil { +// log.Fatalf("CallTool failed: %v", err) +// } +// +// # Clients, servers, and sessions +// +// In this SDK, both a [Client] and [Server] may handle many concurrent +// connections. Each time a client or server is connected to a peer using a +// [Transport], it creates a new session (either a [ClientSession] or +// [ServerSession]): // // Client Server // ⇅ (jsonrpc2) ⇅ // ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession // -// A [Client] is an MCP client, which can be configured with various client -// capabilities. Clients may be connected to a [Server] instance -// using the [Client.Connect] method. -// -// Similarly, a [Server] is an MCP server, which can be configured with various -// server capabilities. Servers may be connected to one or more [Client] -// instances using the [Server.Connect] method, which creates a -// [ServerSession]. -// -// A [Transport] connects a bidirectional [Connection] of jsonrpc2 messages. In -// practice, transports in the MCP spec are are either client transports or -// server transports. For example, the [StdioTransport] is a server transport -// that communicates over stdin/stdout, and its counterpart is a -// [CommandTransport] that communicates with a subprocess over its -// stdin/stdout. -// -// Some transports may hide more complicated details, such as an -// [SSEClientTransport], which reads messages via server-sent events on a -// hanging GET request, and writes them to a POST endpoint. Users of this SDK -// may define their own custom Transports by implementing the [Transport] -// interface. +// The session types expose an API to interact with its peer. For example, +// [ClientSession.CallTool] or [ServerSession.ListRoots]. +// +// # Adding features +// +// Add MCP servers to your Client or Server using AddXXX methods (for example +// [Client.AddRoot] or [Server.AddPrompt]). If any peers are connected when +// AddXXX is called, they will receive a corresponding change notification +// (for example notifications/roots/list_changed). +// +// Adding tools is special: tools may be bound to ordinary Go functions by +// using the top-level generic [AddTool] function, which allows specifying an +// input and output type. When AddTool is used, the tool's input schema and +// output schema are automatically populated, and inputs are automatically +// validated. As a special case, if the output type is 'any', no output schema +// is generated. +// +// func double(_ context.Context, _ *mcp.CallToolRequest, in In) (*mcp.CallToolResponse, Out, error) { +// return nil, Out{Answer: 2*in.Number}, nil +// } +// ... +// mcp.AddTool(&mcp.Tool{Name: "double", Description: "double a number"}, double) package mcp diff --git a/mcp/server.go b/mcp/server.go index 75632f7b..5020823f 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -182,16 +182,17 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { } // ToolFor returns a shallow copy of t and a [ToolHandler] that wraps h. +// // If the tool's input schema is nil, it is set to the schema inferred from the In // type parameter, using [jsonschema.For]. The In type parameter must be a map // or a struct, so that its inferred JSON Schema has type "object". // -// If the tool's output schema is nil and the Out type parameter is not the empty -// interface, then the output schema is set to the schema inferred from Out, which -// must be a map or a struct. +// For tools that don't return structured output, Out should be 'any'. +// Otherwise, if the tool's output schema is nil the output schema is set to +// the schema inferred from Out, which must be a map or a struct. // -// Most users will call [AddTool]. Use [ToolFor] if you wish to modify the tool's -// schemas or wrap the ToolHandler before calling [Server.AddTool]. +// Most users will call [AddTool]. Use [ToolFor] if you wish to modify the +// tool's schemas or wrap the ToolHandler before calling [Server.AddTool]. func ToolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler) { tt, hh, err := toolForErr(t, h) if err != nil { @@ -317,7 +318,16 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) return zero, err } -// AddTool adds a tool and handler to the server. +// AddTool adds a tool and typed tool handler to the server. +// +// If the tool's input schema is nil, it is set to the schema inferred from the +// In type parameter, using [jsonschema.For]. The In type parameter must be a +// map or a struct, so that its inferred JSON Schema has type "object". +// +// For tools that don't return structured output, Out should be 'any'. +// Otherwise, if the tool's output schema is nil the output schema is set to +// the schema inferred from Out, which must be a map or a struct. +// // It is a convenience for s.AddTool(ToolFor(t, h)). func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { s.AddTool(ToolFor(t, h)) From 73dd76b614713a64fc802b39938189eed56283cc Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 25 Aug 2025 14:00:24 -0400 Subject: [PATCH 140/221] go.mod: update to fixed jsonschema (#370) Update to a version that doesn't copy descriptions of subschemas. Fixes #366. --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 17bddeb6..ebcdc591 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 - github.com/google/jsonschema-go v0.2.0 + github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index a2edf9ad..9ae7018a 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.2.0 h1:Uh19091iHC56//WOsAd1oRg6yy1P9BpSvpjOL6RcjLQ= github.com/google/jsonschema-go v0.2.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 h1:mBlBwtDebdDYr+zdop8N62a44g+Nbv7o2KjWyS1deR4= +github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= From c1c22925da095a5de9844bd636f9237e4a8646c4 Mon Sep 17 00:00:00 2001 From: Jim Clark Date: Mon, 25 Aug 2025 12:00:36 -0700 Subject: [PATCH 141/221] relax requirements on resource URIs (#365) remove checks for URI scheme; nothing in the spec requires one # Background When adding resource templates, we are currently doing checks for schemes on absolute uris. It is still common for MCP servers to use custom uris that would not pass these checks. For example, if the official GitHub MCP server were to use this sdk, it would show failures like: ``` "repo://{owner}/{repo}/contents{/path*}": parse "repo://{owner}/{repo}/contents{/path*}": invalid character "{" in host name "repo://{owner}/{repo}/refs/heads/{branch}/contents{/path*}": parse "repo://{owner}/{repo}/refs/heads/{branch}/contents{/path*}": invalid character "{" in host name "repo://{owner}/{repo}/sha/{sha}/contents{/path*}": parse "repo://{owner}/{repo}/sha/{sha}/contents{/path*}": invalid character "{" in host name invalid resource template uri "repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}": parse "repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}": invalid character "{" in host name "repo://{owner}/{repo}/refs/tags/{tag}/contents{/path*}": parse "repo://{owner}/{repo}/refs/tags/{tag}/contents{/path*}": invalid character "{" in host name ``` The wikipedia MCP would also show problems with missing schemes. ``` "/search/{query}": "/article/{title}": "/summary/{title}": "/summary/{title}/query/{query}/length/{max_length}": "/summary/{title}/section/{section_title}/length/{max_length}": "/sections/{title}": "/links/{title}": "/facts/{title}/topic/{topic_within_article}/count/{count}": "/coordinates/{title}": ``` --- mcp/server.go | 14 +------------- mcp/server_test.go | 2 -- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 5020823f..c1440339 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -345,13 +345,9 @@ func (s *Server) RemoveTools(names ...string) { func (s *Server) AddResource(r *Resource, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - u, err := url.Parse(r.URI) - if err != nil { + if _, err := url.Parse(r.URI); err != nil { panic(err) // url.Parse includes the URI in the error } - if !u.IsAbs() { - panic(fmt.Errorf("URI %s needs a scheme", r.URI)) - } s.resources.add(&serverResource{r, h}) return true }) @@ -374,14 +370,6 @@ func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { if err != nil { panic(fmt.Errorf("URI template %q is invalid: %w", t.URITemplate, err)) } - // Ensure the URI template has a valid scheme - u, err := url.Parse(t.URITemplate) - if err != nil { - panic(err) // url.Parse includes the URI in the error - } - if !u.IsAbs() { - panic(fmt.Errorf("URI template %q needs a scheme", t.URITemplate)) - } s.resourceTemplates.add(&serverResourceTemplate{t, h}) return true }) diff --git a/mcp/server_test.go b/mcp/server_test.go index cb1c05dc..7db40738 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -381,8 +381,6 @@ func TestServerAddResourceTemplate(t *testing.T) { }{ {"ValidFileTemplate", "file:///{a}/{b}", false}, {"ValidCustomScheme", "myproto:///{a}", false}, - {"MissingScheme1", "://example.com/{path}", true}, - {"MissingScheme2", "/api/v1/users/{id}", true}, {"EmptyVariable", "file:///{}/{b}", true}, {"UnclosedVariable", "file:///{a", true}, } From 03c51134717482fba89fd93d0fc916ce0208414c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 25 Aug 2025 16:48:10 -0400 Subject: [PATCH 142/221] examples/server/memory: fix misleading code (#369) Tools shouldn't both set StructuredContent and return a typed output. Remove the assignment. --- examples/server/memory/kb.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index c6a59ec0..ad83ca0b 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -442,11 +442,6 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.CallToolRequ res.Content = []mcp.Content{ &mcp.TextContent{Text: "Entities created successfully"}, } - - res.StructuredContent = CreateEntitiesResult{ - Entities: entities, - } - return &res, CreateEntitiesResult{Entities: entities}, nil } From 392f719bd1956e7601cf85f7a9b24c7010cffb4c Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Tue, 26 Aug 2025 09:13:30 -0400 Subject: [PATCH 143/221] mcp/examples: move server example into example folder (#357) For better organization. --- .../server/basic/main.go | 20 ++----------------- mcp/client_list_test.go | 15 ++++++++++++-- mcp/cmd_test.go | 12 +++++++++++ 3 files changed, 27 insertions(+), 20 deletions(-) rename mcp/server_example_test.go => examples/server/basic/main.go (68%) diff --git a/mcp/server_example_test.go b/examples/server/basic/main.go similarity index 68% rename from mcp/server_example_test.go rename to examples/server/basic/main.go index e68dc308..54af6caa 100644 --- a/mcp/server_example_test.go +++ b/examples/server/basic/main.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -package mcp_test +package main import ( "context" @@ -24,7 +24,7 @@ func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mc }, nil, nil } -func ExampleServer() { +func main() { ctx := context.Background() clientTransport, serverTransport := mcp.NewInMemoryTransports() @@ -56,19 +56,3 @@ func ExampleServer() { // Output: Hi user } - -// createSessions creates and connects an in-memory client and server session for testing purposes. -func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { - server := mcp.NewServer(testImpl, nil) - client := mcp.NewClient(testImpl, nil) - serverTransport, clientTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport, nil) - if err != nil { - log.Fatal(err) - } - clientSession, err := client.Connect(ctx, clientTransport, nil) - if err != nil { - log.Fatal(err) - } - return clientSession, serverSession, server -} diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 1449076e..0183a733 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -7,6 +7,7 @@ package mcp_test import ( "context" "iter" + "log" "testing" "github.com/google/go-cmp/cmp" @@ -17,9 +18,19 @@ import ( func TestList(t *testing.T) { ctx := context.Background() - clientSession, serverSession, server := createSessions(ctx) - defer clientSession.Close() + server := mcp.NewServer(testImpl, nil) + client := mcp.NewClient(testImpl, nil) + serverTransport, clientTransport := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + log.Fatal(err) + } defer serverSession.Close() + clientSession, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + log.Fatal(err) + } + defer clientSession.Close() t.Run("tools", func(t *testing.T) { var wantTools []*mcp.Tool diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 6c3a1a76..98354a93 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -22,6 +22,18 @@ import ( const runAsServer = "_MCP_RUN_AS_SERVER" +type SayHiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil +} + func TestMain(m *testing.M) { // If the runAsServer variable is set, execute the relevant serverFunc // instead of running tests (aka the fork and exec trick). From c24d9856c455dd89d3e83a009838f68138ef5c5d Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Tue, 26 Aug 2025 14:22:03 -0700 Subject: [PATCH 144/221] mcp/streamable: fixes broken DELETE request on connection close Signed-off-by: Takeshi Yoneda --- mcp/streamable.go | 2 +- mcp/streamable_test.go | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index f56b7084..a92ab494 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -133,7 +133,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest) return } - } else if !jsonOK || !streamOK { + } else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest) return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 52f61720..d08a0c77 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -687,6 +687,10 @@ func TestStreamableServerTransport(t *testing.T) { }, wantSessionID: true, }, + { + method: "DELETE", + wantStatusCode: http.StatusNoContent, + }, }, }, { @@ -945,7 +949,9 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, req.Header.Set("Mcp-Session-Id", sessionID) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") + if s.method != http.MethodDelete { // DELETE expects "No Content" response. + req.Header.Set("Accept", "application/json, text/event-stream") + } maps.Copy(req.Header, s.headers) resp, err := http.DefaultClient.Do(req) From a76bae3a11c008d59488083185d05a74b86f429c Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Wed, 27 Aug 2025 12:00:37 -0700 Subject: [PATCH 145/221] address review comments Signed-off-by: Takeshi Yoneda --- mcp/streamable.go | 2 +- mcp/streamable_test.go | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index a92ab494..7d9faf04 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -133,7 +133,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest) return } - } else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { + } else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { // TODO: consolidate with handling of http method below. http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest) return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d08a0c77..c99ca782 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -690,6 +690,9 @@ func TestStreamableServerTransport(t *testing.T) { { method: "DELETE", wantStatusCode: http.StatusNoContent, + // Delete request expects 204 No Content with empty body. So override + // the default "accept: application/json, text/event-stream" header. + headers: map[string][]string{"Accept": nil}, }, }, }, @@ -949,9 +952,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, req.Header.Set("Mcp-Session-Id", sessionID) } req.Header.Set("Content-Type", "application/json") - if s.method != http.MethodDelete { // DELETE expects "No Content" response. - req.Header.Set("Accept", "application/json, text/event-stream") - } + req.Header.Set("Accept", "application/json, text/event-stream") maps.Copy(req.Header, s.headers) resp, err := http.DefaultClient.Do(req) From f6118aaace1777dd205f831775e76fa9b898faf6 Mon Sep 17 00:00:00 2001 From: "Huabing (Robin) Zhao" Date: Fri, 29 Aug 2025 22:06:23 +0800 Subject: [PATCH 146/221] mcp: fix the type of the Complete handler The Complete handler was returning an abstract Result type, rather than concrete CompleteResult type. Fixes #375 --- mcp/mcp_test.go | 33 +++++++++++++++++++++++++++++++++ mcp/server.go | 2 +- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5485eff2..8529be5b 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1725,3 +1725,36 @@ func TestPointerArgEquivalence(t *testing.T) { func ptr[T any](v T) *T { return &v } + +func TestComplete(t *testing.T) { + completionValues := []string{"python", "pytorch", "pyside"} + + serverOpts := &ServerOptions{ + CompletionHandler: func(_ context.Context, request *CompleteRequest) (*CompleteResult, error) { + return &CompleteResult{ + Completion: CompletionResultDetails{ + Values: completionValues, + }, + }, nil + }, + } + server := NewServer(testImpl, serverOpts) + cs, _ := basicClientServerConnection(t, nil, server, func(s *Server) {}) + result, err := cs.Complete(context.Background(), &CompleteParams{ + Argument: CompleteParamsArgument{ + Name: "language", + Value: "py", + }, + Ref: &CompleteReference{ + Type: "ref/prompt", + Name: "code_review", + }, + }) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(completionValues, result.Completion.Values); diff != "" { + t.Errorf("Complete() mismatch (-want +got):\n%s", diff) + } +} diff --git a/mcp/server.go b/mcp/server.go index c1440339..dd5d807c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -407,7 +407,7 @@ func (s *Server) capabilities() *ServerCapabilities { return caps } -func (s *Server) complete(ctx context.Context, req *CompleteRequest) (Result, error) { +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteResult, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } From ddaf35ed77479acefd0f39dd8ba3a877efe4b0f3 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Fri, 29 Aug 2025 12:54:51 -0400 Subject: [PATCH 147/221] mcp/streamable: use event store to fix unbounded memory issues (#335) This CL utilizes the event store to write outgoing messages and removes the unbounded outgoing data structure. It also adds a new interface [EventStore.Open] For #190 --- mcp/event.go | 27 +++++++++++- mcp/streamable.go | 106 ++++++++++++++++++++-------------------------- 2 files changed, 72 insertions(+), 61 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index f4f4eeea..0dd8734b 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -153,6 +153,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { // // All of an EventStore's methods must be safe for use by multiple goroutines. type EventStore interface { + // Open prepares the event store for a given stream. It ensures that the + // underlying data structure for the stream is initialized, making it + // ready to store event streams. + Open(_ context.Context, sessionID string, streamID StreamID) error + // Append appends data for an outgoing event to given stream, which is part of the // given session. Append(_ context.Context, sessionID string, _ StreamID, data []byte) error @@ -162,6 +167,7 @@ type EventStore interface { // Once the iterator yields a non-nil error, it will stop. // After's iterator must return an error immediately if any data after index was // dropped; it must not return partial results. + // The stream must have been opened previously (see [EventStore.Open]). After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] // SessionClosed informs the store that the given session is finished, along @@ -256,11 +262,20 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { } } -// Append implements [EventStore.Append] by recording data in memory. -func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { +// Open implements [EventStore.Open]. It ensures that the underlying data +// structures for the given session are initialized and ready for use. +func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error { s.mu.Lock() defer s.mu.Unlock() + s.init(sessionID, streamID) + return nil +} +// init is an internal helper function that ensures the nested map structure for a +// given sessionID and streamID exists, creating it if necessary. It returns the +// dataList associated with the specified IDs. +// Requires s.mu. +func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList { streamMap, ok := s.store[sessionID] if !ok { streamMap = make(map[StreamID]*dataList) @@ -271,6 +286,14 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID dl = &dataList{} streamMap[streamID] = dl } + return dl +} + +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + dl := s.init(sessionID, streamID) // Purge before adding, so at least the current data item will be present. // (That could result in nBytes > maxBytes, but we'll live with that.) s.purge() diff --git a/mcp/streamable.go b/mcp/streamable.go index 7d9faf04..f56eabd9 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -401,7 +401,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp } // Connect implements the [Transport] interface. -func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) { +func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) { if t.connection != nil { return nil, fmt.Errorf("transport already connected") } @@ -415,13 +415,17 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) streams: make(map[StreamID]*stream), requestStreams: make(map[jsonrpc.ID]StreamID), } + if t.connection.eventStore == nil { + t.connection.eventStore = NewMemoryEventStore(nil) + } // Stream 0 corresponds to the hanging 'GET'. // // It is always text/event-stream, since it must carry arbitrarily many // messages. - t.connection.streams[""] = newStream("", false) - if t.connection.eventStore == nil { - t.connection.eventStore = NewMemoryEventStore(nil) + var err error + t.connection.streams[""], err = t.connection.newStream(ctx, "", false) + if err != nil { + return nil, err } return t.connection, nil } @@ -490,7 +494,7 @@ type stream struct { // that there are messages available to write into the HTTP response. // In addition, the presence of a channel guarantees that at most one HTTP response // can receive messages for a logical stream. After claiming the stream, incoming - // requests should read from outgoing, to ensure that no new messages are missed. + // requests should read from the event store, to ensure that no new messages are missed. // // To simplify locking, signal is an atomic. We need an atomic.Pointer, because // you can't set an atomic.Value to nil. @@ -502,22 +506,21 @@ type stream struct { // The following mutable fields are protected by the mutex of the containing // StreamableServerTransport. - // outgoing is the list of outgoing messages, enqueued by server methods that - // write notifications and responses, and dequeued by streamResponse. - outgoing [][]byte - // streamRequests is the set of unanswered incoming RPCs for the stream. // - // Requests persist until their response data has been added to outgoing. + // Requests persist until their response data has been added to the event store. requests map[jsonrpc.ID]struct{} } -func newStream(id StreamID, jsonResponse bool) *stream { +func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } return &stream{ id: id, jsonResponse: jsonResponse, requests: make(map[jsonrpc.ID]struct{}), - } + }, nil } func signalChanPtr() *chan struct{} { @@ -668,7 +671,11 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream = newStream(StreamID(randText()), c.jsonResponse) + stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse) + if err != nil { + http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) + return + } c.mu.Lock() c.streams[stream.id] = stream stream.requests = requests @@ -706,13 +713,13 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter var msgs []json.RawMessage ctx := req.Context() - for msg, ok := range c.messages(ctx, stream, false) { - if !ok { + for msg, err := range c.messages(ctx, stream, false, -1) { + if err != nil { if ctx.Err() != nil { w.WriteHeader(http.StatusNoContent) return } else { - http.Error(w, http.StatusText(http.StatusGone), http.StatusGone) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } } @@ -770,44 +777,18 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, } } - if lastIndex >= 0 { - // Resume. - for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) { - if err != nil { - // TODO: reevaluate these status codes. - // Maybe distinguish between storage errors, which are 500s, and missing - // session or stream ID--can these arise from bad input? - status := http.StatusInternalServerError - if errors.Is(err, ErrEventsPurged) { - status = http.StatusInsufficientStorage - } - errorf(status, "failed to read events: %v", err) - return - } - // The iterator yields events beginning just after lastIndex, or it would have - // yielded an error. - if !write(data) { - return - } - } - } - // Repeatedly collect pending outgoing events and send them. ctx := req.Context() - for msg, ok := range c.messages(ctx, stream, persistent) { - if !ok { + for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { + if err != nil { if ctx.Err() != nil && writes == 0 { // This probably doesn't matter, but respond with NoContent if the client disconnected. w.WriteHeader(http.StatusNoContent) } else { - errorf(http.StatusGone, "stream terminated") + errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } return } - if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil { - errorf(http.StatusInternalServerError, "storing event: %v", err.Error()) - return - } if !write(msg) { return } @@ -816,27 +797,33 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, // messages iterates over messages sent to the current stream. // +// persistent indicates if it is the main GET listener, which should never be +// terminated. +// lastIndex is the index of the last seen event, iteration begins at lastIndex+1. +// // The first iterated value is the received JSON message. The second iterated -// value is an OK value indicating whether the stream terminated normally. +// value is an error value indicating whether the stream terminated normally. +// Iteration ends at the first non-nil error. // // If the stream did not terminate normally, it is either because ctx was // cancelled, or the connection is closed: check the ctx.Err() to differentiate // these cases. -func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] { - return func(yield func(json.RawMessage, bool) bool) { +func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] { + return func(yield func(json.RawMessage, error) bool) { for { c.mu.Lock() - outgoing := stream.outgoing - stream.outgoing = nil nOutstanding := len(stream.requests) c.mu.Unlock() - - for _, data := range outgoing { - if !yield(data, true) { + for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) { + if err != nil { + yield(nil, err) return } + if !yield(data, nil) { + return + } + lastIndex++ } - // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server @@ -850,13 +837,14 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per case <-*stream.signal.Load(): // there are new outgoing messages // return to top of loop case <-c.done: // session is closed - yield(nil, false) + yield(nil, errors.New("session is closed")) return case <-ctx.Done(): - yield(nil, false) + yield(nil, ctx.Err()) return } } + } } @@ -963,9 +951,9 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e stream = c.streams[""] } - // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == "" - // and the client never did a GET), then memory will grow without bound. Consider a mitigation. - stream.outgoing = append(stream.outgoing, data) + if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil { + return fmt.Errorf("error storing event: %w", err) + } if isResponse { // Once we've put the reply on the queue, it's no longer outstanding. delete(stream.requests, forRequest) From 8f11a868987f493ed074b2297a4c7bf8a2003939 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 29 Aug 2025 15:28:05 -0400 Subject: [PATCH 148/221] examples/client: add a loadtest command Add a loadtest client example, to help confirm performance of our streamable transport implementation. For #190 --- examples/client/listfeatures/main.go | 6 +- examples/client/loadtest/main.go | 122 +++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 examples/client/loadtest/main.go diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go index 755a4f98..9d473f0b 100644 --- a/examples/client/listfeatures/main.go +++ b/examples/client/listfeatures/main.go @@ -31,10 +31,10 @@ func main() { flag.Parse() args := flag.Args() if len(args) == 0 { - fmt.Fprintf(os.Stderr, "Usage: listfeatures []") - fmt.Fprintf(os.Stderr, "List all features for a stdio MCP server") + fmt.Fprintln(os.Stderr, "Usage: listfeatures []") + fmt.Fprintln(os.Stderr, "List all features for a stdio MCP server") fmt.Fprintln(os.Stderr) - fmt.Fprintf(os.Stderr, "Example: listfeatures npx @modelcontextprotocol/server-everything") + fmt.Fprintln(os.Stderr, "Example:\n\tlistfeatures npx @modelcontextprotocol/server-everything") os.Exit(2) } diff --git a/examples/client/loadtest/main.go b/examples/client/loadtest/main.go new file mode 100644 index 00000000..2c6a5c03 --- /dev/null +++ b/examples/client/loadtest/main.go @@ -0,0 +1,122 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The load command load tests a streamable MCP server +// +// Usage: loadtest +// +// For example: +// +// loadtest -tool=greet -args='{"name": "foo"}' http://localhost:8080 +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/signal" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + duration = flag.Duration("duration", 1*time.Minute, "duration of the load test") + tool = flag.String("tool", "", "tool to call") + jsonArgs = flag.String("args", "", "JSON arguments to pass") + workers = flag.Int("workers", 10, "number of concurrent workers") + timeout = flag.Duration("timeout", 1*time.Second, "request timeout") + qps = flag.Int("qps", 100, "tool calls per second, per worker") + v = flag.Bool("v", false, "if set, enable verbose logging of results") +) + +func main() { + flag.Usage = func() { + out := flag.CommandLine.Output() + fmt.Fprintf(out, "Usage: loadtest [flags] ") + fmt.Fprintf(out, "Load test a streamable HTTP server (CTRL-C to end early)") + fmt.Fprintln(out) + fmt.Fprintf(out, "Example: loadtest -tool=greet -args='{\"name\": \"foo\"}' http://localhost:8080\n") + fmt.Fprintln(out) + fmt.Fprintln(out, "Flags:") + flag.PrintDefaults() + } + flag.Parse() + args := flag.Args() + if len(args) != 1 || *tool == "" { + flag.Usage() + os.Exit(2) + } + + parentCtx, cancel := context.WithTimeout(context.Background(), *duration) + defer cancel() + parentCtx, restoreSignal := signal.NotifyContext(parentCtx, os.Interrupt) + defer restoreSignal() + + var ( + start = time.Now() + success atomic.Int64 + failure atomic.Int64 + ) + + // Run the test. + var wg sync.WaitGroup + for range *workers { + wg.Add(1) + go func() { + defer wg.Done() + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) + cs, err := client.Connect(parentCtx, &mcp.StreamableClientTransport{Endpoint: args[0]}, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + ticker := time.NewTicker(1 * time.Second / time.Duration(*qps)) + defer ticker.Stop() + + for range ticker.C { + ctx, cancel := context.WithTimeout(parentCtx, *timeout) + defer cancel() + + res, err := cs.CallTool(ctx, &mcp.CallToolParams{Name: *tool, Arguments: json.RawMessage(*jsonArgs)}) + if err != nil { + if parentCtx.Err() != nil { + return // test ended + } + failure.Add(1) + if *v { + log.Printf("FAILURE: %v", err) + } + } else { + success.Add(1) + if *v { + data, err := json.Marshal(res) + if err != nil { + log.Fatalf("marshalling result: %v", err) + } + log.Printf("SUCCESS: %s", string(data)) + } + } + } + }() + } + wg.Wait() + restoreSignal() // call restore signal (redundantly) here to allow ctrl-c to work again + + // Print stats. + var ( + dur = time.Since(start) + succ = success.Load() + fail = failure.Load() + ) + fmt.Printf("Results (in %s):\n", dur) + fmt.Printf("\tsuccess: %d (%g QPS)\n", succ, float64(succ)/dur.Seconds()) + fmt.Printf("\tfailure: %d (%g QPS)\n", fail, float64(fail)/dur.Seconds()) +} From 7bfde44981f838739645a1a4df900c70c831be98 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sat, 30 Aug 2025 08:16:06 -0400 Subject: [PATCH 149/221] mcp: statically type the server-side tool params (#378) Introduce CallToolParamsRaw on the server side, so that tool handlers see in the type system that a tool's arguments are a json.RawMessage. Fixes #377. --- mcp/mcp_test.go | 5 ++--- mcp/protocol.go | 26 ++++++++++++-------------- mcp/requests.go | 2 +- mcp/server.go | 6 ++---- 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 8529be5b..a16ac838 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -494,7 +494,6 @@ func TestEndToEnd(t *testing.T) { if result.Action != "accept" { t.Errorf("got action %q, want %q", result.Action, "accept") } - }) // Disconnect. @@ -1638,7 +1637,7 @@ func TestPointerArgEquivalence(t *testing.T) { // // We handle a few different types of results, to assert they behave the // same in all cases. - AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) { + AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) { switch in.In { case "": return nil, nil, fmt.Errorf("must provide input") @@ -1652,7 +1651,7 @@ func TestPointerArgEquivalence(t *testing.T) { panic("unreachable") } }) - AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) { + AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *CallToolRequest, in input) (*CallToolResult, output, error) { switch in.In { case "": return nil, output{}, fmt.Errorf("must provide input") diff --git a/mcp/protocol.go b/mcp/protocol.go index 382f745f..27860659 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -40,6 +40,7 @@ type Annotations struct { Priority float64 `json:"priority,omitempty"` } +// CallToolParams is used by clients to call a tool. type CallToolParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -48,20 +49,13 @@ type CallToolParams struct { Arguments any `json:"arguments,omitempty"` } -// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments. -func (c *CallToolParams) UnmarshalJSON(data []byte) error { - var raw struct { - Meta `json:"_meta,omitempty"` - Name string `json:"name"` - RawArguments json.RawMessage `json:"arguments,omitempty"` - } - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - c.Meta = raw.Meta - c.Name = raw.Name - c.Arguments = raw.RawArguments - return nil +// CallToolParamsRaw is passed to tool handlers on the server. +type CallToolParamsRaw struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments,omitempty"` } // The server's response to a tool call. @@ -115,6 +109,10 @@ func (x *CallToolParams) isParams() {} func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } + type CancelledParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. diff --git a/mcp/requests.go b/mcp/requests.go index 3afaac5e..82b700f5 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -7,7 +7,7 @@ package mcp type ( - CallToolRequest = ServerRequest[*CallToolParams] + CallToolRequest = ServerRequest[*CallToolParamsRaw] CompleteRequest = ServerRequest[*CompleteParams] GetPromptRequest = ServerRequest[*GetPromptParams] InitializedRequest = ServerRequest[*InitializedParams] diff --git a/mcp/server.go b/mcp/server.go index dd5d807c..c496b33a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -9,7 +9,6 @@ import ( "context" "encoding/base64" "encoding/gob" - "encoding/json" "fmt" "iter" "maps" @@ -234,10 +233,9 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { // Unmarshal and validate args. - rawArgs := req.Params.Arguments.(json.RawMessage) var in In - if rawArgs != nil { - if err := unmarshalSchema(rawArgs, inputResolved, &in); err != nil { + if req.Params.Arguments != nil { + if err := unmarshalSchema(req.Params.Arguments, inputResolved, &in); err != nil { return nil, err } } From a617dcea31a776b8f9c54791ee96e1ed48c39e39 Mon Sep 17 00:00:00 2001 From: Adam Koszek Date: Sat, 30 Aug 2025 11:30:48 -0700 Subject: [PATCH 150/221] mcp/examples: HTTP server example with a simple client built-in (#168) I'm adding a practical example of an MCP with HTTP streaming: both client and server. Those are useful for testing real-world applications. --- examples/http/README.md | 66 +++++++++ examples/http/logging_middleware.go | 51 +++++++ examples/http/main.go | 203 ++++++++++++++++++++++++++++ 3 files changed, 320 insertions(+) create mode 100644 examples/http/README.md create mode 100644 examples/http/logging_middleware.go create mode 100644 examples/http/main.go diff --git a/examples/http/README.md b/examples/http/README.md new file mode 100644 index 00000000..16a4801d --- /dev/null +++ b/examples/http/README.md @@ -0,0 +1,66 @@ +# MCP HTTP Example + +This example demonstrates how to use the Model Context Protocol (MCP) over HTTP using the streamable transport. It includes both a server and client implementation. + +## Overview + +The example implements: +- A server that provides a `cityTime` tool +- A client that connects to the server, lists available tools, and calls the `cityTime` tool + +## Usage + +Start the Server + +```bash +go run main.go server +``` +This starts an MCP server on `http://localhost:8080` (default) that provides a `cityTime` tool. + +To run a client in another terminal: + +```bash +go run main.go client +``` + +The client will: +1. Connect to the server +2. List available tools +3. Call the `cityTime` tool for NYC, San Francisco, and Boston +4. Display the results + +At any given time you can pass a custom URL to the program to run it on a custom host/port: + +``` +go run main.go -host 0.0.0.0 -port 9000 server +``` + +## Testing with real-world MCP Clients + +Once the server is started, assuming it's the default +localhost:8080, you can try to add it to a popular MCP client: + + claude mcp add -t http timezone http://localhost:8080 + +Once added, Claude Code will be able to discover and use the `cityTime` tool provided by this server. + +In Claude Code: + + > what's the timezone + + ⏺ I'll get the current time in a major US city for you. + + ⏺ timezone - cityTime (MCP)(city: "nyc") + ⎿ The current time in New York City is 7:30:16 PM EDT on Wedn + esday, July 23, 2025 + + + ⏺ The current timezone is EDT (Eastern Daylight Time), and it's + 7:30 PM on Wednesday, July 23, 2025. + + > what timezones do you support? + + ⏺ The timezone tool supports three US cities: + - NYC (Eastern Time) + - SF (Pacific Time) + - Boston (Eastern Time) diff --git a/examples/http/logging_middleware.go b/examples/http/logging_middleware.go new file mode 100644 index 00000000..4266012c --- /dev/null +++ b/examples/http/logging_middleware.go @@ -0,0 +1,51 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "log" + "net/http" + "time" +) + +// responseWriter wraps http.ResponseWriter to capture the status code. +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func loggingHandler(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Create a response writer wrapper to capture status code. + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // Log request details. + log.Printf("[REQUEST] %s | %s | %s %s", + start.Format(time.RFC3339), + r.RemoteAddr, + r.Method, + r.URL.Path) + + // Call the actual handler. + handler.ServeHTTP(wrapped, r) + + // Log response details. + duration := time.Since(start) + log.Printf("[RESPONSE] %s | %s | %s %s | Status: %d | Duration: %v", + time.Now().Format(time.RFC3339), + r.RemoteAddr, + r.Method, + r.URL.Path, + wrapped.statusCode, + duration) + }) +} diff --git a/examples/http/main.go b/examples/http/main.go new file mode 100644 index 00000000..682dc8d8 --- /dev/null +++ b/examples/http/main.go @@ -0,0 +1,203 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "os" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + host = flag.String("host", "localhost", "host to connect to/listen on") + port = flag.Int("port", 8000, "port number to connect to/listen on") + proto = flag.String("proto", "http", "if set, use as proto:// part of URL (ignored for server)") +) + +func main() { + out := flag.CommandLine.Output() + flag.Usage = func() { + fmt.Fprintf(out, "Usage: %s [-proto ] [-port ]\n\n", os.Args[0]) + fmt.Fprintf(out, "This program demonstrates MCP over HTTP using the streamable transport.\n") + fmt.Fprintf(out, "It can run as either a server or client.\n\n") + fmt.Fprintf(out, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(out, "\nExamples:\n") + fmt.Fprintf(out, " Run as server: %s server\n", os.Args[0]) + fmt.Fprintf(out, " Run as client: %s client\n", os.Args[0]) + fmt.Fprintf(out, " Custom host/port: %s -port 9000 -host 0.0.0.0 server\n", os.Args[0]) + os.Exit(1) + } + flag.Parse() + + if flag.NArg() != 1 { + fmt.Fprintf(out, "Error: Must specify 'client' or 'server' as first argument\n") + flag.Usage() + } + mode := flag.Arg(0) + + switch mode { + case "server": + if *proto != "http" { + log.Fatalf("Server only works with 'http' (you passed proto=%s)", *proto) + } + runServer(fmt.Sprintf("%s:%d", *host, *port)) + case "client": + runClient(fmt.Sprintf("%s://%s:%d", *proto, *host, *port)) + default: + fmt.Fprintf(os.Stderr, "Error: Invalid mode '%s'. Must be 'client' or 'server'\n\n", mode) + flag.Usage() + } +} + +// GetTimeParams defines the parameters for the cityTime tool. +type GetTimeParams struct { + City string `json:"city" jsonschema:"City to get time for (nyc, sf, or boston)"` +} + +// getTime implements the tool that returns the current time for a given city. +func getTime(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[GetTimeParams]) (*mcp.CallToolResultFor[any], error) { + // Define time zones for each city + locations := map[string]string{ + "nyc": "America/New_York", + "sf": "America/Los_Angeles", + "boston": "America/New_York", + } + + city := params.Arguments.City + if city == "" { + city = "nyc" // Default to NYC + } + + // Get the timezone. + tzName, ok := locations[city] + if !ok { + return nil, fmt.Errorf("unknown city: %s", city) + } + + // Load the location. + loc, err := time.LoadLocation(tzName) + if err != nil { + return nil, fmt.Errorf("failed to load timezone: %w", err) + } + + // Get current time in that location. + now := time.Now().In(loc) + + // Format the response. + cityNames := map[string]string{ + "nyc": "New York City", + "sf": "San Francisco", + "boston": "Boston", + } + + response := fmt.Sprintf("The current time in %s is %s", + cityNames[city], + now.Format(time.RFC3339)) + + return &mcp.CallToolResultFor[any]{ + Content: []mcp.Content{ + &mcp.TextContent{Text: response}, + }, + }, nil +} + +func runServer(url string) { + // Create an MCP server. + server := mcp.NewServer(&mcp.Implementation{ + Name: "time-server", + Version: "1.0.0", + }, nil) + + // Add the cityTime tool. + mcp.AddTool(server, &mcp.Tool{ + Name: "cityTime", + Description: "Get the current time in NYC, San Francisco, or Boston", + }, getTime) + + // Create the streamable HTTP handler. + handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server { + return server + }, nil) + + handlerWithLogging := loggingHandler(handler) + + log.Printf("MCP server listening on %s", url) + log.Printf("Available tool: cityTime (cities: nyc, sf, boston)") + + // Start the HTTP server with logging handler. + if err := http.ListenAndServe(url, handlerWithLogging); err != nil { + log.Fatalf("Server failed: %v", err) + } +} + +func runClient(url string) { + ctx := context.Background() + + // Create the URL for the server. + log.Printf("Connecting to MCP server at %s", url) + + // Create a streamable client transport. + transport := mcp.NewStreamableClientTransport(url, nil) + + // Create an MCP client. + client := mcp.NewClient(&mcp.Implementation{ + Name: "time-client", + Version: "1.0.0", + }, nil) + + // Connect to the server. + session, err := client.Connect(ctx, transport) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + defer session.Close() + + log.Printf("Connected to server (session ID: %s)", session.ID()) + + // First, list available tools. + log.Println("Listing available tools...") + toolsResult, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + + for _, tool := range toolsResult.Tools { + log.Printf(" - %s: %s\n", tool.Name, tool.Description) + } + + // Call the cityTime tool for each city. + cities := []string{"nyc", "sf", "boston"} + + log.Println("Getting time for each city...") + for _, city := range cities { + // Call the tool. + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "cityTime", + Arguments: map[string]any{ + "city": city, + }, + }) + if err != nil { + log.Printf("Failed to get time for %s: %v\n", city, err) + continue + } + + // Print the result. + for _, content := range result.Content { + if textContent, ok := content.(*mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + log.Println("Client completed successfully") +} From be0a00cd0a57c30afdb9aac035b495744cc0c9f3 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 1 Sep 2025 09:24:30 -0400 Subject: [PATCH 151/221] examples/http: update to v0.3.0 Fix build breakage. Fixes #383. --- examples/http/main.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/http/main.go b/examples/http/main.go index 682dc8d8..188674ae 100644 --- a/examples/http/main.go +++ b/examples/http/main.go @@ -64,7 +64,7 @@ type GetTimeParams struct { } // getTime implements the tool that returns the current time for a given city. -func getTime(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[GetTimeParams]) (*mcp.CallToolResultFor[any], error) { +func getTime(ctx context.Context, req *mcp.CallToolRequest, params *GetTimeParams) (*mcp.CallToolResult, any, error) { // Define time zones for each city locations := map[string]string{ "nyc": "America/New_York", @@ -72,7 +72,7 @@ func getTime(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolPar "boston": "America/New_York", } - city := params.Arguments.City + city := params.City if city == "" { city = "nyc" // Default to NYC } @@ -80,13 +80,13 @@ func getTime(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolPar // Get the timezone. tzName, ok := locations[city] if !ok { - return nil, fmt.Errorf("unknown city: %s", city) + return nil, nil, fmt.Errorf("unknown city: %s", city) } // Load the location. loc, err := time.LoadLocation(tzName) if err != nil { - return nil, fmt.Errorf("failed to load timezone: %w", err) + return nil, nil, fmt.Errorf("failed to load timezone: %w", err) } // Get current time in that location. @@ -103,11 +103,11 @@ func getTime(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolPar cityNames[city], now.Format(time.RFC3339)) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: response}, }, - }, nil + }, nil, nil } func runServer(url string) { @@ -145,9 +145,6 @@ func runClient(url string) { // Create the URL for the server. log.Printf("Connecting to MCP server at %s", url) - // Create a streamable client transport. - transport := mcp.NewStreamableClientTransport(url, nil) - // Create an MCP client. client := mcp.NewClient(&mcp.Implementation{ Name: "time-client", @@ -155,7 +152,7 @@ func runClient(url string) { }, nil) // Connect to the server. - session, err := client.Connect(ctx, transport) + session, err := client.Connect(ctx, &mcp.StreamableClientTransport{Endpoint: url}, nil) if err != nil { log.Fatalf("Failed to connect: %v", err) } From a3935c6be225392d4e9ba42a067090e325b0c685 Mon Sep 17 00:00:00 2001 From: Shusaku Yasoda <136243871+yasomaru@users.noreply.github.com> Date: Wed, 3 Sep 2025 00:49:46 +0900 Subject: [PATCH 152/221] examples/server: add example MCP server with JWT and API key (#339) This commit introduces a new example demonstrating the integration of authentication middleware with an MCP server. The server supports both JWT token and API key authentication, along with scope-based access control for various MCP tools. Key features include token generation endpoints, in-memory API key storage, and a health check endpoint. New files added: - `main.go`: Implements the MCP server and authentication logic. - `go.mod` and `go.sum`: Manage dependencies for the project. - `README.md`: Provides setup instructions, available endpoints, and example usage. This example serves as a reference for implementing secure access to MCP tools. Fixes #330 --- examples/server/auth-middleware/README.md | 247 ++++++++++++++ examples/server/auth-middleware/go.mod | 15 + examples/server/auth-middleware/go.sum | 10 + examples/server/auth-middleware/main.go | 392 ++++++++++++++++++++++ 4 files changed, 664 insertions(+) create mode 100644 examples/server/auth-middleware/README.md create mode 100644 examples/server/auth-middleware/go.mod create mode 100644 examples/server/auth-middleware/go.sum create mode 100644 examples/server/auth-middleware/main.go diff --git a/examples/server/auth-middleware/README.md b/examples/server/auth-middleware/README.md new file mode 100644 index 00000000..913022ff --- /dev/null +++ b/examples/server/auth-middleware/README.md @@ -0,0 +1,247 @@ +# MCP Server with Auth Middleware + +This example demonstrates how to integrate the Go MCP SDK's `auth.RequireBearerToken` middleware with an MCP server to provide authenticated access to MCP tools and resources. + +## Features + +The server provides authentication and authorization capabilities for MCP tools: + +### 1. Authentication Methods + +- **JWT Token Authentication**: JSON Web Token-based authentication +- **API Key Authentication**: API key-based authentication +- **Scope-based Access Control**: Permission-based access to MCP tools + +### 2. MCP Integration + +- **Authenticated MCP Tools**: Tools that require authentication and check permissions +- **Token Generation**: Utility endpoints for generating test tokens +- **Middleware Integration**: Seamless integration with MCP server handlers + +## Setup + +```bash +cd examples/server/auth-middleware +go mod tidy +go run main.go +``` + +## Testing + +```bash +# Run all tests +go test -v + +# Run benchmark tests +go test -bench=. + +# Generate coverage report +go test -cover +``` + +## Endpoints + +### Public Endpoints (No Authentication Required) + +- `GET /health` - Health check + +### MCP Endpoints (Authentication Required) + +- `POST /mcp/jwt` - MCP server with JWT authentication +- `POST /mcp/apikey` - MCP server with API key authentication + +### Utility Endpoints + +- `GET /generate-token` - Generate JWT token +- `POST /generate-api-key` - Generate API key + +## Available MCP Tools + +The server provides three authenticated MCP tools: + +### 1. Say Hi (`say_hi`) + +A simple greeting tool that requires authentication. + +**Parameters:** +- None required + +**Required Scopes:** +- Any authenticated user + +### 2. Get User Info (`get_user_info`) + +Retrieves user information based on the provided user ID. + +**Parameters:** +- `user_id` (string): The user ID to get information for + +**Required Scopes:** +- `read` permission + +### 3. Create Resource (`create_resource`) + +Creates a new resource with the provided details. + +**Parameters:** +- `name` (string): The name of the resource +- `description` (string): The description of the resource +- `content` (string): The content of the resource + +**Required Scopes:** +- `write` permission + +## Example Usage + +### 1. Generating JWT Token and Using MCP Tools + +```bash +# Generate a token +curl 'http://localhost:8080/generate-token?user_id=alice&scopes=read,write' + +# Use MCP tool with JWT authentication +curl -H 'Authorization: Bearer ' \ + -H 'Content-Type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"say_hi","arguments":{}}}' \ + http://localhost:8080/mcp/jwt +``` + +### 2. Generating API Key and Using MCP Tools + +```bash +# Generate an API key +curl -X POST 'http://localhost:8080/generate-api-key?user_id=bob&scopes=read' + +# Use MCP tool with API key authentication +curl -H 'Authorization: Bearer ' \ + -H 'Content-Type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_user_info","arguments":{"user_id":"test"}}}' \ + http://localhost:8080/mcp/apikey +``` + +### 3. Testing Scope Restrictions + +```bash +# Access MCP tool requiring write scope +curl -H 'Authorization: Bearer ' \ + -H 'Content-Type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_resource","arguments":{"name":"test","description":"test resource","content":"test content"}}}' \ + http://localhost:8080/mcp/jwt +``` + +## Core Concepts + +### Authentication Integration + +This example demonstrates how to integrate `auth.RequireBearerToken` middleware with an MCP server to provide authenticated access. The MCP server operates as an HTTP handler protected by authentication middleware. + +### Key Features + +1. **MCP Server Integration**: Create MCP server using `mcp.NewServer` +2. **Authentication Middleware**: Protect MCP handlers with `auth.RequireBearerToken` +3. **Token Verification**: Validate tokens using provided `TokenVerifier` functions +4. **Scope Checking**: Verify required permissions (scopes) are present +5. **Expiration Validation**: Check that tokens haven't expired +6. **Context Injection**: Add verified token information to request context +7. **Authenticated MCP Tools**: Tools that operate based on authentication information +8. **Error Handling**: Return appropriate HTTP status codes and error messages on authentication failure + +### Implementation + +```go +// Create MCP server +server := mcp.NewServer(&mcp.Implementation{Name: "authenticated-mcp-server"}, nil) + +// Create authentication middleware +authMiddleware := auth.RequireBearerToken(verifier, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read", "write"}, +}) + +// Create MCP handler +handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server +}, nil) + +// Apply authentication middleware to MCP handler +authenticatedHandler := authMiddleware(customMiddleware(handler)) +``` + +### Parameters + +- **verifier**: Function to verify tokens (`TokenVerifier` type) +- **opts**: Authentication options + - `Scopes`: List of required permissions + - `ResourceMetadataURL`: OAuth 2.0 resource metadata URL + +### Error Responses + +- **401 Unauthorized**: Token is invalid, expired, or missing +- **403 Forbidden**: Required scopes are insufficient +- **WWW-Authenticate Header**: Included when resource metadata URL is configured + +## Implementation Details + +### 1. TokenVerifier Implementation + +```go +func jwtVerifier(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { + // JWT token verification logic + // On success: Return TokenInfo + // On failure: Return auth.ErrInvalidToken +} +``` + +### 2. Using Authentication Information in MCP Tools + +```go +// Get authentication information in MCP tool +func MyTool(ctx context.Context, req *mcp.CallToolRequest, args MyArgs) (*mcp.CallToolResult, any, error) { + // Extract authentication info from request + userInfo := req.Extra.TokenInfo + + // Check scopes + if !slices.Contains(userInfo.Scopes, "read") { + return nil, nil, fmt.Errorf("insufficient permissions: read scope required") + } + + // Execute tool logic + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Tool executed successfully"}, + }, + }, nil, nil +} +``` + +### 3. Middleware Composition + +```go +// Combine authentication middleware with custom middleware +authenticatedHandler := authMiddleware(customMiddleware(mcpHandler)) +``` + +## Security Best Practices + +1. **Environment Variables**: Use environment variables for JWT secrets in production +2. **Database Storage**: Store API keys in a database +3. **HTTPS Usage**: Always use HTTPS in production environments +4. **Token Expiration**: Set appropriate token expiration times +5. **Principle of Least Privilege**: Grant only the minimum required scopes + +## Use Cases + +**Ideal for:** + +- MCP servers requiring authentication and authorization +- Applications needing scope-based access control +- Systems requiring both JWT and API key authentication +- Projects needing secure MCP tool access +- Scenarios requiring audit trails and permission management + +**Examples:** + +- Enterprise MCP servers with user management +- Multi-tenant MCP applications +- Secure API gateways with MCP integration +- Development environments with authentication requirements +- Production systems requiring fine-grained access control diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod new file mode 100644 index 00000000..402c0aae --- /dev/null +++ b/examples/server/auth-middleware/go.mod @@ -0,0 +1,15 @@ +module auth-middleware-example + +go 1.23.0 + +require ( + github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/modelcontextprotocol/go-sdk v0.3.0 +) + +require ( + github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect +) + +replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum new file mode 100644 index 00000000..eea5bdb5 --- /dev/null +++ b/examples/server/auth-middleware/go.sum @@ -0,0 +1,10 @@ +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 h1:mBlBwtDebdDYr+zdop8N62a44g+Nbv7o2KjWyS1deR4= +github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/server/auth-middleware/main.go b/examples/server/auth-middleware/main.go new file mode 100644 index 00000000..f472b760 --- /dev/null +++ b/examples/server/auth-middleware/main.go @@ -0,0 +1,392 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "slices" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// This example demonstrates how to integrate auth.RequireBearerToken middleware +// with an MCP server to provide authenticated access to MCP tools and resources. + +var httpAddr = flag.String("http", ":8080", "HTTP address to listen on") + +// JWTClaims represents the claims in our JWT tokens. +// In a real application, you would include additional claims like issuer, audience, etc. +type JWTClaims struct { + UserID string `json:"user_id"` // User identifier + Scopes []string `json:"scopes"` // Permissions/roles for the user + jwt.RegisteredClaims +} + +// APIKey represents an API key with associated scopes. +// In production, this would be stored in a database with additional metadata. +type APIKey struct { + Key string `json:"key"` // The actual API key value + UserID string `json:"user_id"` // User identifier + Scopes []string `json:"scopes"` // Permissions/roles for this key +} + +// In-memory storage for API keys (in production, use a database). +// This is for demonstration purposes only. +var apiKeys = map[string]*APIKey{ + "sk-1234567890abcdef": { + Key: "sk-1234567890abcdef", + UserID: "user1", + Scopes: []string{"read", "write"}, + }, + "sk-abcdef1234567890": { + Key: "sk-abcdef1234567890", + UserID: "user2", + Scopes: []string{"read"}, + }, +} + +// JWT secret (in production, use environment variables). +// This should be a strong, randomly generated secret in real applications. +var jwtSecret = []byte("your-secret-key") + +// generateToken creates a JWT token for testing purposes. +// In a real application, this would be handled by your authentication service. +func generateToken(userID string, scopes []string, expiresIn time.Duration) (string, error) { + // Create JWT claims with user information and scopes. + claims := JWTClaims{ + UserID: userID, + Scopes: scopes, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiresIn)), // Token expiration + IssuedAt: jwt.NewNumericDate(time.Now()), // Token issuance time + NotBefore: jwt.NewNumericDate(time.Now()), // Token validity start time + }, + } + + // Create and sign the JWT token. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(jwtSecret) +} + +// verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware. +// This function implements the TokenVerifier interface required by auth.RequireBearerToken. +func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { + // Parse and validate the JWT token. + token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) { + // Verify the signing method is HMAC. + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return jwtSecret, nil + }) + + if err != nil { + // Return standard error for invalid tokens. + return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + + // Extract claims and verify token validity. + if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + return &auth.TokenInfo{ + Scopes: claims.Scopes, // User permissions + Expiration: claims.ExpiresAt.Time, // Token expiration time + }, nil + } + + return nil, fmt.Errorf("%w: invalid token claims", auth.ErrInvalidToken) +} + +// verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware. +// This function implements the TokenVerifier interface required by auth.RequireBearerToken. +func verifyAPIKey(ctx context.Context, apiKey string) (*auth.TokenInfo, error) { + // Look up the API key in our storage. + key, exists := apiKeys[apiKey] + if !exists { + return nil, fmt.Errorf("%w: API key not found", auth.ErrInvalidToken) + } + + // API keys don't expire in this example, but you could add expiration logic here. + // For demonstration, we set a 24-hour expiration. + return &auth.TokenInfo{ + Scopes: key.Scopes, // User permissions + Expiration: time.Now().Add(24 * time.Hour), // 24 hour expiration + }, nil +} + +// MCP Tool Arguments +type getUserInfoArgs struct { + UserID string `json:"user_id" jsonschema:"the user ID to get information for"` +} + +type createResourceArgs struct { + Name string `json:"name" jsonschema:"the name of the resource"` + Description string `json:"description" jsonschema:"the description of the resource"` + Content string `json:"content" jsonschema:"the content of the resource"` +} + +// SayHi is a simple MCP tool that requires authentication +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args struct{}) (*mcp.CallToolResult, any, error) { + // Extract user information from request (v0.3.0+) + userInfo := req.Extra.TokenInfo + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Hello! You have scopes: %v", userInfo.Scopes)}, + }, + }, nil, nil +} + +// GetUserInfo is an MCP tool that requires read scope +func GetUserInfo(ctx context.Context, req *mcp.CallToolRequest, args getUserInfoArgs) (*mcp.CallToolResult, any, error) { + // Extract user information from request (v0.3.0+) + userInfo := req.Extra.TokenInfo + + // Check if user has read scope. + if !slices.Contains(userInfo.Scopes, "read") { + return nil, nil, fmt.Errorf("insufficient permissions: read scope required") + } + + userData := map[string]any{ + "requested_user_id": args.UserID, + "your_scopes": userInfo.Scopes, + "message": "User information retrieved successfully", + } + + userDataJSON, err := json.Marshal(userData) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal user data: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: string(userDataJSON)}, + }, + }, nil, nil +} + +// CreateResource is an MCP tool that requires write scope +func CreateResource(ctx context.Context, req *mcp.CallToolRequest, args createResourceArgs) (*mcp.CallToolResult, any, error) { + // Extract user information from request (v0.3.0+) + userInfo := req.Extra.TokenInfo + + // Check if user has write scope. + if !slices.Contains(userInfo.Scopes, "write") { + return nil, nil, fmt.Errorf("insufficient permissions: write scope required") + } + + resourceInfo := map[string]any{ + "name": args.Name, + "description": args.Description, + "content": args.Content, + "created_by": "authenticated_user", + "created_at": time.Now().Format(time.RFC3339), + } + + resourceInfoJSON, err := json.Marshal(resourceInfo) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal resource info: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Resource created successfully: %s", string(resourceInfoJSON))}, + }, + }, nil, nil +} + +// authMiddleware extracts token information and adds it to the context +func authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // In a real application, you would extract token info from the auth middleware's context + // For this example, we simulate the token info that would be available + ctx := context.WithValue(r.Context(), "user_info", &auth.TokenInfo{ + Scopes: []string{"read", "write"}, + Expiration: time.Now().Add(time.Hour), + }) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// createMCPServer creates an MCP server with authentication-aware tools +func createMCPServer() *mcp.Server { + server := mcp.NewServer(&mcp.Implementation{Name: "authenticated-mcp-server"}, nil) + + // Add tools that require authentication. + mcp.AddTool(server, &mcp.Tool{ + Name: "say_hi", + Description: "A simple greeting tool that requires authentication", + }, SayHi) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_user_info", + Description: "Get user information (requires read scope)", + }, GetUserInfo) + + mcp.AddTool(server, &mcp.Tool{ + Name: "create_resource", + Description: "Create a new resource (requires write scope)", + }, CreateResource) + + return server +} + +func main() { + flag.Parse() + + // Create the MCP server. + server := createMCPServer() + + // Create authentication middleware. + jwtAuth := auth.RequireBearerToken(verifyJWT, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, // Require "read" permission + }) + + apiKeyAuth := auth.RequireBearerToken(verifyAPIKey, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, // Require "read" permission + }) + + // Create HTTP handler with authentication. + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, nil) + + // Apply authentication middleware to the MCP handler. + authenticatedHandler := jwtAuth(authMiddleware(handler)) + apiKeyHandler := apiKeyAuth(authMiddleware(handler)) + + // Create router for different authentication methods. + http.HandleFunc("/mcp/jwt", authenticatedHandler.ServeHTTP) + http.HandleFunc("/mcp/apikey", apiKeyHandler.ServeHTTP) + + // Add utility endpoints for token generation. + http.HandleFunc("/generate-token", func(w http.ResponseWriter, r *http.Request) { + // Get user ID from query parameters (default: "test-user"). + userID := r.URL.Query().Get("user_id") + if userID == "" { + userID = "test-user" + } + + // Get scopes from query parameters (default: ["read", "write"]). + scopes := strings.Split(r.URL.Query().Get("scopes"), ",") + if len(scopes) == 1 && scopes[0] == "" { + scopes = []string{"read", "write"} + } + + // Get expiration time from query parameters (default: 1 hour). + expiresIn := 1 * time.Hour + if expStr := r.URL.Query().Get("expires_in"); expStr != "" { + if exp, err := time.ParseDuration(expStr); err == nil { + expiresIn = exp + } + } + + // Generate the JWT token. + token, err := generateToken(userID, scopes, expiresIn) + if err != nil { + http.Error(w, "Failed to generate token", http.StatusInternalServerError) + return + } + + // Return the generated token. + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token": token, + "type": "Bearer", + }) + }) + + http.HandleFunc("/generate-api-key", func(w http.ResponseWriter, r *http.Request) { + // Generate a random API key using cryptographically secure random bytes. + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + http.Error(w, "Failed to generate random bytes", http.StatusInternalServerError) + return + } + apiKey := "sk-" + base64.URLEncoding.EncodeToString(bytes) + + // Get user ID from query parameters (default: "test-user"). + userID := r.URL.Query().Get("user_id") + if userID == "" { + userID = "test-user" + } + + // Get scopes from query parameters (default: ["read"]). + scopes := strings.Split(r.URL.Query().Get("scopes"), ",") + if len(scopes) == 1 && scopes[0] == "" { + scopes = []string{"read"} + } + + // Store the new API key in our in-memory storage. + // In production, this would be stored in a database. + apiKeys[apiKey] = &APIKey{ + Key: apiKey, + UserID: userID, + Scopes: scopes, + } + + // Return the generated API key. + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "api_key": apiKey, + "type": "Bearer", + }) + }) + + // Health check endpoint. + http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "healthy", + "time": time.Now().Format(time.RFC3339), + }) + }) + + // Start the HTTP server. + log.Println("Authenticated MCP Server") + log.Println("========================") + log.Println("Server starting on", *httpAddr) + log.Println() + log.Println("Available endpoints:") + log.Println(" GET /health - Health check (no auth)") + log.Println(" GET /generate-token - Generate JWT token") + log.Println(" POST /generate-api-key - Generate API key") + log.Println(" POST /mcp/jwt - MCP endpoint (JWT auth)") + log.Println(" POST /mcp/apikey - MCP endpoint (API key auth)") + log.Println() + log.Println("Available MCP Tools:") + log.Println(" - say_hi - Simple greeting (any auth)") + log.Println(" - get_user_info - Get user info (read scope)") + log.Println(" - create_resource - Create resource (write scope)") + log.Println() + log.Println("Example usage:") + log.Println(" # Generate a token") + log.Println(" curl 'http://localhost:8080/generate-token?user_id=alice&scopes=read,write'") + log.Println() + log.Println(" # Use MCP with JWT authentication") + log.Println(" curl -H 'Authorization: Bearer ' -H 'Content-Type: application/json' \\") + log.Println(" -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"say_hi\",\"arguments\":{}}}' \\") + log.Println(" http://localhost:8080/mcp/jwt") + log.Println() + log.Println(" # Generate an API key") + log.Println(" curl -X POST 'http://localhost:8080/generate-api-key?user_id=bob&scopes=read'") + log.Println() + log.Println(" # Use MCP with API key authentication") + log.Println(" curl -H 'Authorization: Bearer ' -H 'Content-Type: application/json' \\") + log.Println(" -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"get_user_info\",\"arguments\":{\"user_id\":\"test\"}}}' \\") + log.Println(" http://localhost:8080/mcp/apikey") + + log.Fatal(http.ListenAndServe(*httpAddr, nil)) +} From 24fc13eb88446dbda49c3cbb1b990a50fd096ffd Mon Sep 17 00:00:00 2001 From: Brad Hoekstra Date: Tue, 2 Sep 2025 14:19:53 -0400 Subject: [PATCH 153/221] cleanup: remove reference to deleted jsonschema package Signed-off-by: Brad Hoekstra --- README.md | 6 +----- internal/readme/README.src.md | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index bffd48aa..126aa52e 100644 --- a/README.md +++ b/README.md @@ -33,16 +33,12 @@ open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. ## Package documentation -The SDK consists of three importable packages: +The SDK consists of two importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) package defines the primary APIs for constructing and using MCP clients and servers. -- The - [`github.com/modelcontextprotocol/go-sdk/jsonschema`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) - package provides an implementation of [JSON - Schema](https://json-schema.org/), used for MCP tool input and output schema. - The [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing their own transports. diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index bafcbc73..40e0aa19 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -32,16 +32,12 @@ open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. ## Package documentation -The SDK consists of three importable packages: +The SDK consists of two importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) package defines the primary APIs for constructing and using MCP clients and servers. -- The - [`github.com/modelcontextprotocol/go-sdk/jsonschema`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) - package provides an implementation of [JSON - Schema](https://json-schema.org/), used for MCP tool input and output schema. - The [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing their own transports. From 07b65d76a7fdb710e757f26e2c03418609bc1b4b Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Tue, 2 Sep 2025 14:32:12 -0400 Subject: [PATCH 154/221] examples/server/distributed: a stateless distributed server Add an example of a server that proxies HTTP requests to multiple stateless streamable MCP backends. For #148 --- examples/client/loadtest/main.go | 12 +- examples/server/distributed/main.go | 166 ++++++++++++++++++++++++++++ examples/server/everything/main.go | 28 ++++- 3 files changed, 198 insertions(+), 8 deletions(-) create mode 100644 examples/server/distributed/main.go diff --git a/examples/client/loadtest/main.go b/examples/client/loadtest/main.go index 2c6a5c03..d5a04c2e 100644 --- a/examples/client/loadtest/main.go +++ b/examples/client/loadtest/main.go @@ -33,7 +33,7 @@ var ( workers = flag.Int("workers", 10, "number of concurrent workers") timeout = flag.Duration("timeout", 1*time.Second, "request timeout") qps = flag.Int("qps", 100, "tool calls per second, per worker") - v = flag.Bool("v", false, "if set, enable verbose logging of results") + verbose = flag.Bool("v", false, "if set, enable verbose logging") ) func main() { @@ -56,8 +56,8 @@ func main() { parentCtx, cancel := context.WithTimeout(context.Background(), *duration) defer cancel() - parentCtx, restoreSignal := signal.NotifyContext(parentCtx, os.Interrupt) - defer restoreSignal() + parentCtx, stop := signal.NotifyContext(parentCtx, os.Interrupt) + defer stop() var ( start = time.Now() @@ -91,12 +91,12 @@ func main() { return // test ended } failure.Add(1) - if *v { + if *verbose { log.Printf("FAILURE: %v", err) } } else { success.Add(1) - if *v { + if *verbose { data, err := json.Marshal(res) if err != nil { log.Fatalf("marshalling result: %v", err) @@ -108,7 +108,7 @@ func main() { }() } wg.Wait() - restoreSignal() // call restore signal (redundantly) here to allow ctrl-c to work again + stop() // restore the interrupt signal // Print stats. var ( diff --git a/examples/server/distributed/main.go b/examples/server/distributed/main.go new file mode 100644 index 00000000..43d959e6 --- /dev/null +++ b/examples/server/distributed/main.go @@ -0,0 +1,166 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The distributed command is an example of a distributed MCP server. +// +// It forks multiple child processes (according to the -child_ports flag), each +// of which is a streamable HTTP MCP server with the 'inc' tool, and proxies +// incoming http requests to them. +// +// Distributed MCP servers must be stateless, because there's no guarantee that +// subsequent requests for a session land on the same backend. However, they +// may still have logical session IDs, as can be seen with verbose logging +// (-v). +// +// Example: +// +// ./distributed -http=localhost:8080 -child_ports=8081,8082 +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" + "os" + "os/exec" + "os/signal" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const childPortVar = "MCP_CHILD_PORT" + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + childPorts = flag.String("child_ports", "", "comma-separated child ports to distribute to") + verbose = flag.Bool("v", false, "if set, enable verbose logging") +) + +func main() { + // This command runs as either a parent or a child, depending on whether + // childPortVar is set (a.k.a. the fork-and-exec trick). + // + // Each child is a streamable HTTP server, and the parent is a reverse proxy. + flag.Parse() + if v := os.Getenv(childPortVar); v != "" { + child(v) + } else { + parent() + } +} + +func parent() { + exe, err := os.Executable() + if err != nil { + log.Fatal(err) + } + + if *httpAddr == "" { + log.Fatal("must provide -http") + } + if *childPorts == "" { + log.Fatal("must provide -child_ports") + } + + // Ensure that children are cleaned up on CTRL-C + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + // Start the child processes. + ports := strings.Split(*childPorts, ",") + var wg sync.WaitGroup + childURLs := make([]*url.URL, len(ports)) + for i, port := range ports { + wg.Add(1) + childURL := fmt.Sprintf("http://localhost:%s", port) + childURLs[i], err = url.Parse(childURL) + if err != nil { + log.Fatal(err) + } + cmd := exec.CommandContext(ctx, exe, os.Args[1:]...) + cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", childPortVar, port)) + cmd.Stderr = os.Stderr + go func() { + defer wg.Done() + log.Printf("starting child %d at %s", i, childURL) + if err := cmd.Run(); err != nil && ctx.Err() == nil { + log.Printf("child %d failed: %v", i, err) + } else { + log.Printf("child %d exited gracefully", i) + } + }() + } + + // Start a reverse proxy that round-robin's requests to each backend. + var nextBackend atomic.Int64 + server := http.Server{ + Addr: *httpAddr, + Handler: &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + child := int(nextBackend.Add(1)) % len(childURLs) + if *verbose { + log.Printf("dispatching to localhost:%s", ports[child]) + } + r.SetURL(childURLs[child]) + }, + }, + } + + wg.Add(1) + go func() { + defer wg.Done() + if err := server.ListenAndServe(); err != nil && ctx.Err() == nil { + log.Printf("Server failed: %v", err) + } + }() + + log.Printf("Serving at %s (CTRL-C to cancel)", *httpAddr) + + <-ctx.Done() + stop() // restore the interrupt signal + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Attempt the graceful shutdown. + if err := server.Shutdown(shutdownCtx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + + // Wait for the subprocesses and http server to stop. + wg.Wait() + + log.Println("Server shutdown gracefully.") +} + +func child(port string) { + // Create a server with a single tool that increments a counter. + server := mcp.NewServer(&mcp.Implementation{Name: "counter"}, nil) + + var count atomic.Int64 + inc := func(ctx context.Context, req *mcp.CallToolRequest, args struct{}) (*mcp.CallToolResult, struct{ Count int64 }, error) { + n := count.Add(1) + if *verbose { + log.Printf("request %d (session %s)", n, req.Session.ID()) + } + return nil, struct{ Count int64 }{n}, nil + } + mcp.AddTool(server, &mcp.Tool{Name: "inc"}, inc) + + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{ + Stateless: true, + }) + log.Printf("child listening on localhost:%s", port) + log.Fatal(http.ListenAndServe(fmt.Sprintf("localhost:%s", port), handler)) +} diff --git a/examples/server/everything/main.go b/examples/server/everything/main.go index d2b7b337..0b81919f 100644 --- a/examples/server/everything/main.go +++ b/examples/server/everything/main.go @@ -11,19 +11,40 @@ import ( "fmt" "log" "net/http" + _ "net/http/pprof" "net/url" "os" + "runtime" "strings" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) -var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + pprofAddr = flag.String("pprof", "", "if set, host the pprof debugging server at this address") +) func main() { flag.Parse() + if *pprofAddr != "" { + // For debugging memory leaks, add an endpoint to trigger a few garbage + // collection cycles and ensure the /debug/pprof/heap endpoint only reports + // reachable memory. + http.DefaultServeMux.Handle("/gc", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + for range 3 { + runtime.GC() + } + fmt.Fprintln(w, "GC'ed") + })) + go func() { + // DefaultServeMux was mutated by the /debug/pprof import. + http.ListenAndServe(*pprofAddr, http.DefaultServeMux) + }() + } + opts := &mcp.ServerOptions{ Instructions: "Use this server!", CompletionHandler: complete, // support completions by setting this handler @@ -56,7 +77,10 @@ func main() { return server }, nil) log.Printf("MCP handler listening at %s", *httpAddr) - http.ListenAndServe(*httpAddr, handler) + if *pprofAddr != "" { + log.Printf("pprof listening at http://%s/debug/pprof", *pprofAddr) + } + log.Fatal(http.ListenAndServe(*httpAddr, handler)) } else { t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} if err := server.Run(context.Background(), t); err != nil { From 7769b2a959e8a793dabd9974994ba4ba139e1b84 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 10:37:27 -0400 Subject: [PATCH 155/221] mcp: set content to marshaled output (#398) The ToolHandler constructed by ToolFor sets the result's Content to the marshaled output, following the spec's suggestion. See https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. Also, set StructuredContent to the marshaled RawMessage to avoid a double marshal. Fixes #391. --- mcp/server.go | 30 ++++++++++++++++++++---------- mcp/streamable_test.go | 10 ++++++++-- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index c496b33a..2e7fbcc6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -9,6 +9,7 @@ import ( "context" "encoding/base64" "encoding/gob" + "encoding/json" "fmt" "iter" "maps" @@ -261,26 +262,35 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // TODO(v0.3.0): Validate out. _ = outputResolved - // TODO: return the serialized JSON in a TextContent block, as per spec? - // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content - // But people may use res.Content for other things. if res == nil { res = &CallToolResult{} } - if res.Content == nil { - res.Content = []Content{} // avoid returning 'null' - } - res.StructuredContent = out + // Marshal the output and put the RawMessage in the StructuredContent field. + var outval any = out if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. - // Instead, use the zero value of the non-zero + // Instead, use the zero value of the unpointered type. var z Out if any(out) == any(z) { // zero is only non-nil if Out is a pointer type - res.StructuredContent = elemZero + outval = elemZero } } + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. + if res.Content == nil { + res.Content = []Content{&TextContent{ + Text: string(outbytes), + }} + } return res, nil - } + } // end of handler return &tt, th, nil } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index c99ca782..0cb7f955 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1175,7 +1175,10 @@ func TestStreamableStateless(t *testing.T) { req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "World"}}), }, wantMessages: []jsonrpc.Message{ - resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi World"}}}, nil), + resp(2, &CallToolResult{ + Content: []Content{&TextContent{Text: "hi World"}}, + StructuredContent: json.RawMessage("null"), + }, nil), }, wantSessionID: false, }, @@ -1186,7 +1189,10 @@ func TestStreamableStateless(t *testing.T) { req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}), }, wantMessages: []jsonrpc.Message{ - resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil), + resp(2, &CallToolResult{ + Content: []Content{&TextContent{Text: "hi foo"}}, + StructuredContent: json.RawMessage("null"), + }, nil), }, wantSessionID: false, }, From 063fb121491cf576c3a425445b813af668c4bab4 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 11:33:40 -0400 Subject: [PATCH 156/221] mcp: unexport ToolFor (#402) The rationale is provided on the issue. Also, document the low-level nature of Server.AddTool. (drive-by change) Fixes #401. --- mcp/server.go | 42 ++++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 2e7fbcc6..114489ae 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -154,10 +154,18 @@ func (s *Server) RemovePrompts(names ...string) { // If present, the output schema must also have type "object". // // When the handler is invoked as part of a CallTool request, req.Params.Arguments -// will be a json.RawMessage. Unmarshaling the arguments and validating them against the -// input schema are the handler author's responsibility. +// will be a json.RawMessage. // -// Most users should use the top-level function [AddTool]. +// Unmarshaling the arguments and validating them against the input schema are the +// caller's responsibility. +// +// Validating the result against the output schema, if any, is the caller's responsibility. +// +// Setting the result's Content, StructuredContent and IsError fields are the caller's +// responsibility. +// +// Most users should use the top-level function [AddTool], which handles all these +// responsibilities. func (s *Server) AddTool(t *Tool, h ToolHandler) { if t.InputSchema == nil { // This prevents the tool author from forgetting to write a schema where @@ -181,26 +189,6 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { func() bool { s.tools.add(st); return true }) } -// ToolFor returns a shallow copy of t and a [ToolHandler] that wraps h. -// -// If the tool's input schema is nil, it is set to the schema inferred from the In -// type parameter, using [jsonschema.For]. The In type parameter must be a map -// or a struct, so that its inferred JSON Schema has type "object". -// -// For tools that don't return structured output, Out should be 'any'. -// Otherwise, if the tool's output schema is nil the output schema is set to -// the schema inferred from Out, which must be a map or a struct. -// -// Most users will call [AddTool]. Use [ToolFor] if you wish to modify the -// tool's schemas or wrap the ToolHandler before calling [Server.AddTool]. -func ToolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler) { - tt, hh, err := toolForErr(t, h) - if err != nil { - panic(fmt.Sprintf("ToolFor: tool %q: %v", t.Name, err)) - } - return tt, hh -} - // TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { tt := *t @@ -335,10 +323,12 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) // For tools that don't return structured output, Out should be 'any'. // Otherwise, if the tool's output schema is nil the output schema is set to // the schema inferred from Out, which must be a map or a struct. -// -// It is a convenience for s.AddTool(ToolFor(t, h)). func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - s.AddTool(ToolFor(t, h)) + tt, hh, err := toolForErr(t, h) + if err != nil { + panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) + } + s.AddTool(tt, hh) } // RemoveTools removes the tools with the given names. From 0bb1a42b53c08869cdc65281504081706c3627b8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 11:34:02 -0400 Subject: [PATCH 157/221] examples/server/middleware: log tool result (#400) Log the result of a tool call in the middleware. This example now fully demonstrates that receiving middleware can effectively wrap a ToolHandler. That means that one reason for ToolFor is moot: you don't need to get your hands on the returned ToolHandler in order to wrap it. --- examples/server/middleware/main.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/server/middleware/main.go b/examples/server/middleware/main.go index 224c8c6f..4e37471e 100644 --- a/examples/server/middleware/main.go +++ b/examples/server/middleware/main.go @@ -64,6 +64,12 @@ func main() { "duration_ms", duration.Milliseconds(), "has_result", result != nil, ) + // Log more for tool results. + if ctr, ok := result.(*mcp.CallToolResult); ok { + logger.Info("tool result", + "isError", ctr.IsError, + "structuredContent", ctr.StructuredContent) + } } return result, err } @@ -103,7 +109,7 @@ func main() { Content: []mcp.Content{ &mcp.TextContent{Text: message}, }, - }, nil, nil + }, message, nil }, ) From c76d991f0533dac9d3668c8d3406f65d29d47d9a Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 3 Sep 2025 11:34:39 -0400 Subject: [PATCH 158/221] examples/server/distributed: address comment in #380 I failed to push my branch before merging #380. Address one minor comment from that PR. --- examples/server/distributed/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/server/distributed/main.go b/examples/server/distributed/main.go index 43d959e6..0433f1fa 100644 --- a/examples/server/distributed/main.go +++ b/examples/server/distributed/main.go @@ -80,7 +80,6 @@ func parent() { var wg sync.WaitGroup childURLs := make([]*url.URL, len(ports)) for i, port := range ports { - wg.Add(1) childURL := fmt.Sprintf("http://localhost:%s", port) childURLs[i], err = url.Parse(childURL) if err != nil { @@ -89,6 +88,8 @@ func parent() { cmd := exec.CommandContext(ctx, exe, os.Args[1:]...) cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", childPortVar, port)) cmd.Stderr = os.Stderr + + wg.Add(1) go func() { defer wg.Done() log.Printf("starting child %d at %s", i, childURL) From 1c20560beb05b1b1c84431eb9dd96d7d17266731 Mon Sep 17 00:00:00 2001 From: ccpro <92025731+CCpro10@users.noreply.github.com> Date: Wed, 3 Sep 2025 23:35:19 +0800 Subject: [PATCH 159/221] mcp: export jsonResponse field Exporting StreamableHTTPOptions.JSONResponse is helpful for some users behind proxies that don't work well with text/event-stream. For #211 Co-authored-by: ccpro10 --- mcp/streamable.go | 8 ++++---- mcp/streamable_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index f56eabd9..f1a8ca12 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -72,8 +72,8 @@ type StreamableHTTPOptions struct { // TODO: support session retention (?) - // jsonResponse is forwarded to StreamableServerTransport.jsonResponse. - jsonResponse bool + // JSONResponse is forwarded to StreamableServerTransport.jsonResponse. + JSONResponse bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. @@ -233,7 +233,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque transport = &StreamableServerTransport{ SessionID: sessionID, Stateless: h.opts.Stateless, - jsonResponse: h.opts.jsonResponse, + jsonResponse: h.opts.JSONResponse, } // To support stateless mode, we initialize the session with a default @@ -487,7 +487,7 @@ type stream struct { // jsonResponse records whether this stream should respond with application/json // instead of text/event-stream. // - // See [StreamableServerTransportOptions.jsonResponse]. + // See [StreamableServerTransportOptions.JSONResponse]. jsonResponse bool // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0cb7f955..49c2e87f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -84,7 +84,7 @@ func TestStreamableTransports(t *testing.T) { // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ - jsonResponse: useJSON, + JSONResponse: useJSON, }) var ( From 1250a31b5101e473f727a766c167f8d79bf64c33 Mon Sep 17 00:00:00 2001 From: Shusaku Yasoda <136243871+yasomaru@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:05:07 +0900 Subject: [PATCH 160/221] auth: add OAuth error handling in TokenVerifier and tests (#399) Add ErrOAuth error type and handling to match TypeScript SDK behavior. OAuth protocol errors now return HTTP 400 instead of 500, providing better error classification for authentication issues. Changes - Add ErrOAuth variable for OAuth-specific protocol errors - Update verify function to return 400 for OAuth errors - Add test case for OAuth error handling Fixes compatibility with TypeScript SDK error handling patterns. --- auth/auth.go | 7 ++++++- auth/auth_test.go | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/auth/auth.go b/auth/auth.go index 14ad28c7..ce908f62 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -24,6 +24,9 @@ type TokenInfo struct { // The error that a TokenVerifier should return if the token cannot be verified. var ErrInvalidToken = errors.New("invalid token") +// The error that a TokenVerifier should return for OAuth-specific protocol errors. +var ErrOAuth = errors.New("oauth error") + // A TokenVerifier checks the validity of a bearer token, and extracts information // from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) @@ -88,7 +91,9 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke if errors.Is(err, ErrInvalidToken) { return nil, err.Error(), http.StatusUnauthorized } - // TODO: the TS SDK distinguishes another error, OAuthError, and returns a 400. + if errors.Is(err, ErrOAuth) { + return nil, err.Error(), http.StatusBadRequest + } // Investigate how that works. // See typescript-sdk/src/server/auth/middleware/bearerAuth.ts. return nil, err.Error(), http.StatusInternalServerError diff --git a/auth/auth_test.go b/auth/auth_test.go index 715b9bba..8da41ec6 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -19,6 +19,8 @@ func TestVerify(t *testing.T) { return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil case "invalid": return nil, ErrInvalidToken + case "oauth": + return nil, ErrOAuth case "noexp": return &TokenInfo{}, nil case "expired": @@ -47,6 +49,10 @@ func TestVerify(t *testing.T) { "invalid", nil, "bearer invalid", "invalid token", 401, }, + { + "oauth error", nil, "Bearer oauth", + "oauth error", 400, + }, { "no expiration", nil, "Bearer noexp", "token missing expiration", 401, From 29c1650e2af89513c70a5b26cb78e83f67f84441 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 14:20:53 -0400 Subject: [PATCH 161/221] internal/oauthex: OAuth extensions (#125) Add a package for the extensions to OAuth 2.0 required by MCP. This first PR adds Protected Resource Metadata. --- internal/oauthex/oauth2.go | 6 + internal/oauthex/oauth2_test.go | 270 +++++++++++++++++++++ internal/oauthex/resource_meta.go | 382 ++++++++++++++++++++++++++++++ 3 files changed, 658 insertions(+) create mode 100644 internal/oauthex/oauth2.go create mode 100644 internal/oauthex/oauth2_test.go create mode 100644 internal/oauthex/resource_meta.go diff --git a/internal/oauthex/oauth2.go b/internal/oauthex/oauth2.go new file mode 100644 index 00000000..d1166fe1 --- /dev/null +++ b/internal/oauthex/oauth2.go @@ -0,0 +1,6 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex diff --git a/internal/oauthex/oauth2_test.go b/internal/oauthex/oauth2_test.go new file mode 100644 index 00000000..92017f81 --- /dev/null +++ b/internal/oauthex/oauth2_test.go @@ -0,0 +1,270 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestSplitChallenges(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "single challenge no params", + input: `Basic`, + want: []string{`Basic`}, + }, + { + name: "single challenge with params", + input: `Bearer realm="example.com", error="invalid_token"`, + want: []string{`Bearer realm="example.com", error="invalid_token"`}, + }, + { + name: "single challenge with comma in quoted string", + input: `Bearer realm="example, with comma"`, + want: []string{`Bearer realm="example, with comma"`}, + }, + { + name: "two challenges", + input: `Basic, Bearer realm="example"`, + want: []string{`Basic`, ` Bearer realm="example"`}, + }, + { + name: "multiple challenges complex", + input: `Newauth realm="apps", Basic, Bearer realm="example.com", error="invalid_token"`, + want: []string{`Newauth realm="apps"`, ` Basic`, ` Bearer realm="example.com", error="invalid_token"`}, + }, + { + name: "challenge with escaped quote", + input: `Bearer realm="example \"quoted\""`, + want: []string{`Bearer realm="example \"quoted\""`}, + }, + { + name: "empty input", + input: "", + want: []string{""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := splitChallenges(tt.input) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("splitChallenges() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSplitChallengesError(t *testing.T) { + if _, err := splitChallenges(`"Bearer"`); err == nil { + t.Fatal("got nil, want error") + } +} + +func TestParseSingleChallenge(t *testing.T) { + tests := []struct { + name string + input string + want challenge + wantErr bool + }{ + { + name: "scheme only", + input: "Basic", + want: challenge{ + Scheme: "basic", + }, + wantErr: false, + }, + { + name: "scheme with one quoted param", + input: `Bearer realm="example.com"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example.com"}, + }, + wantErr: false, + }, + { + name: "scheme with one unquoted param", + input: `Bearer realm=example.com`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example.com"}, + }, + wantErr: false, + }, + { + name: "scheme with multiple params", + input: `Bearer realm="example", error="invalid_token", error_description="The token expired"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{ + "realm": "example", + "error": "invalid_token", + "error_description": "The token expired", + }, + }, + wantErr: false, + }, + { + name: "scheme with multiple unquoted params", + input: `Bearer realm=example, error=invalid_token, error_description=The token expired`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{ + "realm": "example", + "error": "invalid_token", + "error_description": "The token expired", + }, + }, + wantErr: false, + }, + { + name: "case-insensitive scheme and keys", + input: `BEARER ReAlM="example"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example"}, + }, + wantErr: false, + }, + { + name: "param with escaped quote", + input: `Bearer realm="example \"foo\" bar"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": `example "foo" bar`}, + }, + wantErr: false, + }, + { + name: "param without quotes (token)", + input: "Bearer realm=example.com", + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example.com"}, + }, + wantErr: false, + }, + { + name: "malformed param - no value", + input: "Bearer realm=", + wantErr: true, + }, + { + name: "malformed param - unterminated quote", + input: `Bearer realm="example`, + wantErr: true, + }, + { + name: "malformed param - missing comma", + input: `Bearer realm="a" error="b"`, + wantErr: true, + }, + { + name: "malformed param - initial equal", + input: `Bearer ="a"`, + wantErr: true, + }, + { + name: "empty input", + input: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseSingleChallenge(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseSingleChallenge() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseSingleChallenge() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetProtectedResourceMetadata(t *testing.T) { + ctx := context.Background() + t.Run("FromHeader", func(t *testing.T) { + h := &fakeResourceHandler{serveWWWAuthenticate: true} + server := httptest.NewTLSServer(h) + h.installHandlers(server.URL) + client := server.Client() + res, err := client.Get(server.URL + "/resource") + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusUnauthorized { + t.Fatal("want unauth") + } + prm, err := GetProtectedResourceMetadataFromHeader(ctx, res.Header, client) + if err != nil { + t.Fatal(err) + } + if prm == nil { + t.Fatal("nil prm") + } + }) + t.Run("FromID", func(t *testing.T) { + h := &fakeResourceHandler{serveWWWAuthenticate: false} + server := httptest.NewTLSServer(h) + h.installHandlers(server.URL) + client := server.Client() + prm, err := GetProtectedResourceMetadataFromID(ctx, server.URL, client) + if err != nil { + t.Fatal(err) + } + if prm == nil { + t.Fatal("nil prm") + } + }) +} + +type fakeResourceHandler struct { + http.ServeMux + serveWWWAuthenticate bool +} + +func (h *fakeResourceHandler) installHandlers(serverURL string) { + path := "/.well-known/oauth-protected-resource" + url := serverURL + path + h.Handle("GET /resource", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if h.serveWWWAuthenticate { + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, url)) + } + w.WriteHeader(http.StatusUnauthorized) + })) + h.Handle("GET "+path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // If there is a WWW-Authenticate header, the resource field is the value of that header. + // If not, it's the server URL without the "/.well-known/..." part. + resource := serverURL + if h.serveWWWAuthenticate { + resource = url + } + prm := &ProtectedResourceMetadata{Resource: resource} + if err := json.NewEncoder(w).Encode(prm); err != nil { + panic(err) + } + })) +} diff --git a/internal/oauthex/resource_meta.go b/internal/oauthex/resource_meta.go new file mode 100644 index 00000000..eb981d2d --- /dev/null +++ b/internal/oauthex/resource_meta.go @@ -0,0 +1,382 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +package oauthex + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "unicode" + + "github.com/modelcontextprotocol/go-sdk/internal/util" +) + +const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resource" + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server by its ID. +// The resource ID is an HTTPS URL, typically with a host:port and possibly a path. +// For example: +// +// https://example.com/server +// +// This function, following the spec (§3), inserts the default well-known path into the +// URL. In our example, the result would be +// +// https://example.com/.well-known/oauth-protected-resource/server +// +// It then retrieves the metadata at that location using the given client (or the +// default client if nil) and validates its resource field against resourceID. +func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) + + u, err := url.Parse(resourceID) + if err != nil { + return nil, err + } + // Insert well-known URI into URL. + u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) + return getPRM(ctx, u.String(), c, resourceID) +} + +// GetProtectedResourceMetadataFromHeader retrieves protected resource metadata +// using information in the given header, using the given client (or the default +// client if nil). +// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request, +// It then validates the resource field of the resulting metadata against the given URL. +// If there is no URL in the request, it returns nil, nil. +func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") + headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] + if len(headers) == 0 { + return nil, nil + } + cs, err := parseWWWAuthenticate(headers) + if err != nil { + return nil, err + } + url := resourceMetadataURL(cs) + if url == "" { + return nil, nil + } + return getPRM(ctx, url, c, url) +} + +// getPRM makes a GET request to the given URL, and validates the response. +// As part of the validation, it compares the returned resource field to wantResource. +func getPRM(ctx context.Context, url string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { + if !strings.HasPrefix(strings.ToUpper(url), "HTTPS://") { + return nil, fmt.Errorf("resource URL %q does not use HTTPS", url) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if c == nil { + c = http.DefaultClient + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Spec §3.2 requires a 200. + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status %s", res.Status) + } + // Spec §3.2 requires application/json. + if ct := res.Header.Get("Content-Type"); ct != "application/json" { + return nil, fmt.Errorf("bad content type %q", ct) + } + + var prm ProtectedResourceMetadata + dec := json.NewDecoder(res.Body) + if err := dec.Decode(&prm); err != nil { + return nil, err + } + // Validate the Resource field to thwart impersonation attacks (section 3.3). + if prm.Resource != wantResource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + } + return &prm, nil +} + +// challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +// resourceMetadataURL returns a resource metadata URL from the given challenges, +// or the empty string if there is none. +func resourceMetadataURL(cs []challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// parseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. +func parseWWWAuthenticate(headers []string) ([]challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. + challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. + commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} From d0c5943f25c91bf12887ec70746ea99f8056e8b0 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 16:23:42 -0400 Subject: [PATCH 162/221] add HTTP Request to TokenVerifier (#404) The request might be needed to verify the token. Fixes #403. --- auth/auth.go | 14 ++++++++------ auth/auth_test.go | 8 +++++--- examples/server/auth-middleware/main.go | 5 ++--- mcp/streamable_test.go | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index ce908f62..7cc0074a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -29,7 +29,8 @@ var ErrOAuth = errors.New("oauth error") // A TokenVerifier checks the validity of a bearer token, and extracts information // from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. -type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) +// The HTTP request is provided in case verifying the token involves checking it. +type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) // RequireBearerTokenOptions are options for [RequireBearerToken]. type RequireBearerTokenOptions struct { @@ -55,6 +56,8 @@ func TokenInfoFromContext(ctx context.Context) *TokenInfo { // If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. // If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header // is populated to enable [protected resource metadata]. +// + // // [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { @@ -62,7 +65,7 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenInfo, errmsg, code := verify(r.Context(), verifier, opts, r.Header.Get("Authorization")) + tokenInfo, errmsg, code := verify(r, verifier, opts) if code != 0 { if code == http.StatusUnauthorized || code == http.StatusForbidden { if opts != nil && opts.ResourceMetadataURL != "" { @@ -78,15 +81,16 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) } } -func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerTokenOptions, authHeader string) (_ *TokenInfo, errmsg string, code int) { +func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) { // Extract bearer token. + authHeader := req.Header.Get("Authorization") fields := strings.Fields(authHeader) if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { return nil, "no bearer token", http.StatusUnauthorized } // Verify the token and get information from it. - tokenInfo, err := verifier(ctx, fields[1]) + tokenInfo, err := verifier(req.Context(), fields[1], req) if err != nil { if errors.Is(err, ErrInvalidToken) { return nil, err.Error(), http.StatusUnauthorized @@ -94,8 +98,6 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke if errors.Is(err, ErrOAuth) { return nil, err.Error(), http.StatusBadRequest } - // Investigate how that works. - // See typescript-sdk/src/server/auth/middleware/bearerAuth.ts. return nil, err.Error(), http.StatusInternalServerError } diff --git a/auth/auth_test.go b/auth/auth_test.go index 8da41ec6..ef8ea7b3 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -7,13 +7,13 @@ package auth import ( "context" "errors" + "net/http" "testing" "time" ) func TestVerify(t *testing.T) { - ctx := context.Background() - verifier := func(_ context.Context, token string) (*TokenInfo, error) { + verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) { switch token { case "valid": return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil @@ -67,7 +67,9 @@ func TestVerify(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - _, gotMsg, gotCode := verify(ctx, verifier, tt.opts, tt.header) + _, gotMsg, gotCode := verify(&http.Request{ + Header: http.Header{"Authorization": {tt.header}}, + }, verifier, tt.opts) if gotMsg != tt.wantMsg || gotCode != tt.wantCode { t.Errorf("got (%q, %d), want (%q, %d)", gotMsg, gotCode, tt.wantMsg, tt.wantCode) } diff --git a/examples/server/auth-middleware/main.go b/examples/server/auth-middleware/main.go index f472b760..dd1271eb 100644 --- a/examples/server/auth-middleware/main.go +++ b/examples/server/auth-middleware/main.go @@ -83,7 +83,7 @@ func generateToken(userID string, scopes []string, expiresIn time.Duration) (str // verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware. // This function implements the TokenVerifier interface required by auth.RequireBearerToken. -func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { +func verifyJWT(ctx context.Context, tokenString string, _ *http.Request) (*auth.TokenInfo, error) { // Parse and validate the JWT token. token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) { // Verify the signing method is HMAC. @@ -92,7 +92,6 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) } return jwtSecret, nil }) - if err != nil { // Return standard error for invalid tokens. return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) @@ -111,7 +110,7 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) // verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware. // This function implements the TokenVerifier interface required by auth.RequireBearerToken. -func verifyAPIKey(ctx context.Context, apiKey string) (*auth.TokenInfo, error) { +func verifyAPIKey(ctx context.Context, apiKey string, _ *http.Request) (*auth.TokenInfo, error) { // Look up the API key in our storage. key, exists := apiKeys[apiKey] if !exists { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 49c2e87f..2c90425d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1270,7 +1270,7 @@ func TestTokenInfo(t *testing.T) { AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - verifier := func(context.Context, string) (*auth.TokenInfo, error) { + verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) { return &auth.TokenInfo{ Scopes: []string{"scope"}, // Expiration is far, far in the future. From 8314ec099b7145dc2a4891bb5fa00c184c10d367 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 16:25:48 -0400 Subject: [PATCH 163/221] mcp: validate tool output (#352) Validate the handler's returned output against the output schema of a tool. Fixes #301. --- mcp/mcp_test.go | 20 ++++++++++++++++++++ mcp/server.go | 8 ++++++-- mcp/tool.go | 14 +++++++++----- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index a16ac838..28a8b974 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -257,6 +257,26 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } + // Check output schema validation. + badout := &Tool{ + Name: "badout", + OutputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "x": {Type: "string"}, + }, + }, + } + AddTool(s, badout, func(_ context.Context, _ *CallToolRequest, arg map[string]any) (*CallToolResult, map[string]any, error) { + return nil, map[string]any{"x": 1}, nil + }) + _, err = cs.CallTool(ctx, &CallToolParams{Name: "badout"}) + wantMsg := `has type "integer", want "string"` + if err == nil || !strings.Contains(err.Error(), wantMsg) { + t.Errorf("\ngot %q\nwant error message containing %q", err, wantMsg) + } + + // Check tools-changed notifications. s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{Type: "object"}}, nopHandler) waitForNotification(t, "tools") s.RemoveTools("T") diff --git a/mcp/server.go b/mcp/server.go index 114489ae..0c1acc24 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -247,8 +247,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan }, nil } - // TODO(v0.3.0): Validate out. - _ = outputResolved + // Validate output schema, if any. + // Skip if out is nil: we've removed "null" from the output schema, so nil won't validate. + if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() { + } else if err := validateSchema(outputResolved, &out); err != nil { + return nil, fmt.Errorf("tool output: %w", err) + } if res == nil { res = &CallToolResult{} diff --git a/mcp/tool.go b/mcp/tool.go index bd10a07c..53a3c7aa 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + // "log" "github.com/google/jsonschema-go/jsonschema" ) @@ -42,13 +43,16 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) if err := dec.Decode(v); err != nil { return fmt.Errorf("unmarshaling: %w", err) } - // TODO: test with nil args. + return validateSchema(resolved, v) +} + +func validateSchema(resolved *jsonschema.Resolved, value any) error { if resolved != nil { - if err := resolved.ApplyDefaults(v); err != nil { - return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%s:\n%w", schemaJSON(resolved.Schema()), data, err) + if err := resolved.ApplyDefaults(value); err != nil { + return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%v:\n%w", schemaJSON(resolved.Schema()), value, err) } - if err := resolved.Validate(v); err != nil { - return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err) + if err := resolved.Validate(value); err != nil { + return fmt.Errorf("validating\n\t%v\nagainst\n\t %s:\n %w", value, schemaJSON(resolved.Schema()), err) } } return nil From 92e19c1be8be2c863af290b1bf78a19c7623b2c2 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 17:27:38 -0400 Subject: [PATCH 164/221] examples/server/auth-middleware: remove custom middleware (#405) The additional middleware added the TokenInfo the context. RequireBearerToken already does that. --- examples/server/auth-middleware/main.go | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/examples/server/auth-middleware/main.go b/examples/server/auth-middleware/main.go index dd1271eb..75a38b86 100644 --- a/examples/server/auth-middleware/main.go +++ b/examples/server/auth-middleware/main.go @@ -206,19 +206,6 @@ func CreateResource(ctx context.Context, req *mcp.CallToolRequest, args createRe }, nil, nil } -// authMiddleware extracts token information and adds it to the context -func authMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // In a real application, you would extract token info from the auth middleware's context - // For this example, we simulate the token info that would be available - ctx := context.WithValue(r.Context(), "user_info", &auth.TokenInfo{ - Scopes: []string{"read", "write"}, - Expiration: time.Now().Add(time.Hour), - }) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - // createMCPServer creates an MCP server with authentication-aware tools func createMCPServer() *mcp.Server { server := mcp.NewServer(&mcp.Implementation{Name: "authenticated-mcp-server"}, nil) @@ -263,8 +250,8 @@ func main() { }, nil) // Apply authentication middleware to the MCP handler. - authenticatedHandler := jwtAuth(authMiddleware(handler)) - apiKeyHandler := apiKeyAuth(authMiddleware(handler)) + authenticatedHandler := jwtAuth(handler) + apiKeyHandler := apiKeyAuth(handler) // Create router for different authentication methods. http.HandleFunc("/mcp/jwt", authenticatedHandler.ServeHTTP) From a1eb4849b74ceaf7fdc99661b0533ab72bf73c1d Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 3 Sep 2025 20:09:19 +0000 Subject: [PATCH 165/221] mcp: correctly disallow GET requests on stateless servers The condition sessionID == "" was not quite right for disallowing GET requests. Since we decided to differentiate "stateless" vs "sessionless" servers, we need to disallow GET requests for both. For #393 --- mcp/streamable.go | 2 +- mcp/streamable_test.go | 36 ++++++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index f1a8ca12..0469613b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -171,7 +171,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && sessionID == "" { + if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2c90425d..6b69b210 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -799,7 +799,9 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream out := make(chan jsonrpc.Message) // Cancel the step if we encounter a request that isn't going to be // handled. - ctx, cancel := context.WithCancel(context.Background()) + // + // Also, add a timeout (hopefully generous). + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) var wg sync.WaitGroup wg.Add(1) @@ -1168,6 +1170,11 @@ func TestStreamableStateless(t *testing.T) { wantBodyContaining: "greet", wantSessionID: false, }, + { + method: "GET", + wantStatusCode: http.StatusMethodNotAllowed, + wantSessionID: false, + }, { method: "POST", wantStatusCode: http.StatusOK, @@ -1215,33 +1222,34 @@ func TestStreamableStateless(t *testing.T) { } } - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ GetSessionID: func() string { return "" }, Stateless: true, }) - // Test the default stateless mode. - t.Run("stateless", func(t *testing.T) { - testStreamableHandler(t, handler, requests) - testClientCompatibility(t, handler) + // First, test the "sessionless" stateless mode, where there is no session ID. + t.Run("sessionless", func(t *testing.T) { + testStreamableHandler(t, sessionlessHandler, requests) + testClientCompatibility(t, sessionlessHandler) }) - // Test a "distributed" variant of stateless mode, where it has non-empty - // session IDs, but is otherwise stateless. + // Next, test the default stateless mode, where session IDs are permitted. // // This can be used by tools to look up application state preserved across // subsequent requests. for i, req := range requests { - // Now, we want a session for all requests. - req.wantSessionID = true + // Now, we want a session for all (valid) requests. + if req.wantStatusCode != http.StatusMethodNotAllowed { + req.wantSessionID = true + } requests[i] = req } - distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ Stateless: true, }) - t.Run("distributed", func(t *testing.T) { - testStreamableHandler(t, distributableHandler, requests) - testClientCompatibility(t, handler) + t.Run("stateless", func(t *testing.T) { + testStreamableHandler(t, statelessHandler, requests) + testClientCompatibility(t, sessionlessHandler) }) } From 4257528dde13d228f1fcfe3a9d60d9df4a0d4887 Mon Sep 17 00:00:00 2001 From: ln-12 <36760115+ln-12@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:10:39 +0200 Subject: [PATCH 166/221] mcp: set default logging level to info (#411) During initialization, the default logging level is set to info allowing servers to emit logs immediately without waiting for a client to set a log level. The updated example shows how to send a notification in the context of the request. Fixes #387 --- examples/server/distributed/main.go | 5 ++++- mcp/streamable.go | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/server/distributed/main.go b/examples/server/distributed/main.go index 0433f1fa..b1440402 100644 --- a/examples/server/distributed/main.go +++ b/examples/server/distributed/main.go @@ -144,7 +144,7 @@ func parent() { } func child(port string) { - // Create a server with a single tool that increments a counter. + // Create a server with a single tool that increments a counter and sends a notification. server := mcp.NewServer(&mcp.Implementation{Name: "counter"}, nil) var count atomic.Int64 @@ -153,6 +153,9 @@ func child(port string) { if *verbose { log.Printf("request %d (session %s)", n, req.Session.ID()) } + // Send a notification in the context of the request + // Hint: in stateless mode, at least log level 'info' is required to send notifications + req.Session.Log(ctx, &mcp.LoggingMessageParams{Data: fmt.Sprintf("request %d (session %s)", n, req.Session.ID()), Level: "info"}) return nil, struct{ Count int64 }{n}, nil } mcp.AddTool(server, &mcp.Tool{Name: "inc"}, inc) diff --git a/mcp/streamable.go b/mcp/streamable.go index 0469613b..0f6bfdcc 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -282,6 +282,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque if !hasInitialized { state.InitializedParams = new(InitializedParams) } + state.LogLevel = "info" connectOpts = &ServerSessionOptions{ State: state, } From 203792dfdc04d75712bd6c469029a4622b939c42 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 4 Sep 2025 11:13:17 -0400 Subject: [PATCH 167/221] mcp: improve docs around tool handlers (#414) --- mcp/server.go | 10 +++++++++- mcp/tool.go | 6 ++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 0c1acc24..bb10079c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -319,7 +319,6 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) } // AddTool adds a tool and typed tool handler to the server. -// // If the tool's input schema is nil, it is set to the schema inferred from the // In type parameter, using [jsonschema.For]. The In type parameter must be a // map or a struct, so that its inferred JSON Schema has type "object". @@ -327,6 +326,15 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) // For tools that don't return structured output, Out should be 'any'. // Otherwise, if the tool's output schema is nil the output schema is set to // the schema inferred from Out, which must be a map or a struct. +// +// The In argument to the handler will contain the unmarshaled arguments from +// CallToolRequest.Params.Arguments. Most users can ignore the [CallToolRequest] +// argument to the handler. +// +// The handler's Out return value will be used to populate [CallToolResult.StructuredContent]. +// If the handler returns a non-nil error, [CallToolResult.IsError] will be set to true, +// and [CallToolResult.Content] will be set to the text of the error. +// Most users can ignore the [CallToolResult] return value. func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { tt, hh, err := toolForErr(t, h) if err != nil { diff --git a/mcp/tool.go b/mcp/tool.go index 53a3c7aa..1e8ec8bd 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -15,11 +15,13 @@ import ( ) // A ToolHandler handles a call to tools/call. -// [CallToolParams.Arguments] will contain a map[string]any that has been validated -// against the input schema. +// This is a low-level API, for use with [Server.AddTool]. +// Most users will write a [ToolHandlerFor] and install it with [AddTool]. type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) // A ToolHandlerFor handles a call to tools/call with typed arguments and results. +// Use [AddTool] to add a ToolHandlerFor to a server. +// Most users can ignore the [CallToolRequest] argument and [CallToolResult] return value. type ToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) // A serverTool is a tool definition that is bound to a tool handler. From 07b9cee8ccbe0afa649822fd42b5473233d2f15b Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 14:00:15 +0000 Subject: [PATCH 168/221] mcp: flush headers immediately for the hanging GET Flush headers immediately for the persistent hanging GET of the streamable transport; otherwise, clients may time out. Fixes #410 --- mcp/streamable.go | 37 ++++++++++++--------- mcp/streamable_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 15 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 0f6bfdcc..e7777eb0 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -743,17 +743,28 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter // lastIndex is the index of the last seen event if resuming, else -1. func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { - writes := 0 - - // Accept checked in [StreamableHTTPHandler] + // Accept was checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") if c.sessionID != "" { w.Header().Set(sessionIDHeader, c.sessionID) } + if persistent { + // Issue #410: the hanging GET is likely not to receive messages for a long + // time. Ensure that headers are flushed. + // + // For non-persistent requests, delay the writing of the header in case we + // may want to set an error status. + // (see the TODO: this probably isn't worth it). + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } // write one event containing data. + writes := 0 write := func(data []byte) bool { lastIndex++ e := Event{ @@ -770,23 +781,19 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, return true } - errorf := func(code int, format string, args ...any) { - if writes == 0 { - http.Error(w, fmt.Sprintf(format, args...), code) - } else { - // TODO(#170): log when we add server-side logging - } - } - // Repeatedly collect pending outgoing events and send them. ctx := req.Context() for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { if err != nil { - if ctx.Err() != nil && writes == 0 { - // This probably doesn't matter, but respond with NoContent if the client disconnected. - w.WriteHeader(http.StatusNoContent) + if ctx.Err() == nil && writes == 0 && !persistent { + // If we haven't yet written the header, we have an opportunity to + // promote an error to an HTTP error. + // + // TODO: This may not matter in practice, in which case we should + // simplify. + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } else { - errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + // TODO(#170): log when we add server-side logging } return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6b69b210..0d171d83 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1312,3 +1312,77 @@ func TestTokenInfo(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +func TestStreamableGET(t *testing.T) { + // This test checks the fix for problematic behavior described in #410: + // Hanging GET headers should be written immediately, even if there are no + // messages. + server := NewServer(testImpl, nil) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + newReq := func(method string, msg jsonrpc.Message) *http.Request { + var body io.Reader + if msg != nil { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatal(err) + } + body = bytes.NewReader(data) + } + req, err := http.NewRequestWithContext(ctx, method, httpServer.URL, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "application/json, text/event-stream") + if msg != nil { + req.Header.Set("Content-Type", "application/json") + } + return req + } + + get1 := newReq(http.MethodGet, nil) + resp, err := http.DefaultClient.Do(get1) + if err != nil { + t.Fatal(err) + } + if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { + t.Errorf("initial GET: got status %d, want %d", got, want) + } + defer resp.Body.Close() + + post1 := newReq(http.MethodPost, req(1, methodInitialize, &InitializeParams{})) + resp, err = http.DefaultClient.Do(post1) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Errorf("initialize POST: got status %d, want %d; body:\n%s", got, want, string(body)) + } + + sessionID := resp.Header.Get(sessionIDHeader) + if sessionID == "" { + t.Fatalf("initialized missing session ID") + } + + get2 := newReq("GET", nil) + get2.Header.Set(sessionIDHeader, sessionID) + resp, err = http.DefaultClient.Do(get2) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("GET with session ID: got status %d, want %d", got, want) + } +} From 30261728f72dc9f85dc2be3572699fefe097864d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 4 Sep 2025 14:04:27 -0400 Subject: [PATCH 169/221] mcp: validate user-provided output schemas (#408) toolForErr was ignoring the output schema if the output type was any. That neglected the case where the user provided their own output schema. Fixes #371. --- mcp/server.go | 4 ++-- mcp/server_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++ mcp/tool.go | 1 - 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index bb10079c..19d902ff 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -212,7 +212,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan elemZero any // only non-nil if Out is a pointer type outputResolved *jsonschema.Resolved ) - if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { var err error elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) if err != nil { @@ -302,8 +302,8 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we // should have a jsonschema.Zero(schema) helper? func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { - rt := reflect.TypeFor[T]() if *sfield == nil { + rt := reflect.TypeFor[T]() if rt.Kind() == reflect.Pointer { rt = rt.Elem() zero = reflect.Zero(rt).Interface() diff --git a/mcp/server_test.go b/mcp/server_test.go index 7db40738..e46be379 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "encoding/json" "log" "slices" "testing" @@ -487,3 +488,59 @@ func TestAddTool(t *testing.T) { t.Error("bad Out: expected panic") } } + +type schema = jsonschema.Schema + +func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErr bool) { + t.Helper() + th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { + return nil, out, nil + } + gott, goth, err := toolForErr(tool, th) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(wantIn, gott.InputSchema); diff != "" { + t.Errorf("input: mismatch (-want, +got):\n%s", diff) + } + if diff := cmp.Diff(wantOut, gott.OutputSchema); diff != "" { + t.Errorf("output: mismatch (-want, +got):\n%s", diff) + } + ctr := &CallToolRequest{ + Params: &CallToolParamsRaw{ + Arguments: json.RawMessage(in), + }, + } + _, err = goth(context.Background(), ctr) + + if gotErr := err != nil; gotErr != wantErr { + t.Errorf("got error: %t, want error: %t", gotErr, wantErr) + } +} + +func TestToolForSchemas(t *testing.T) { + // Validate that ToolFor handles schemas properly. + + // Infer both schemas. + testToolForSchema[int](t, &Tool{}, "3", true, + &schema{Type: "integer"}, &schema{Type: "boolean"}, false) + // Validate the input schema: expect an error if it's wrong. + // We can't test that the output schema is validated, because it's typed. + testToolForSchema[int](t, &Tool{}, `"x"`, true, + &schema{Type: "integer"}, &schema{Type: "boolean"}, true) + + // Ignore type any for output. + testToolForSchema[int, any](t, &Tool{}, "3", 0, + &schema{Type: "integer"}, nil, false) + // Input is still validated. + testToolForSchema[int, any](t, &Tool{}, `"x"`, 0, + &schema{Type: "integer"}, nil, true) + + // Tool sets input schema: that is what's used. + testToolForSchema[int, any](t, &Tool{InputSchema: &schema{Type: "string"}}, "3", 0, + &schema{Type: "string"}, nil, true) // error: 3 is not a string + + // Tool sets output schema: that is what's used, and validation happens. + testToolForSchema[string, any](t, &Tool{OutputSchema: &schema{Type: "integer"}}, "3", "x", + &schema{Type: "string"}, &schema{Type: "integer"}, true) // error: "x" is not an integer +} diff --git a/mcp/tool.go b/mcp/tool.go index 1e8ec8bd..0797700f 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -9,7 +9,6 @@ import ( "context" "encoding/json" "fmt" - // "log" "github.com/google/jsonschema-go/jsonschema" ) From a4313f92bf5acf136777895ed34c16bdf8eb6468 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 15:34:39 +0000 Subject: [PATCH 170/221] mcp: refactor the streamable client test to be more flexible Use a fake streamable server to facilitate testing client behavior. For this commit, just update the existing test (moved to a new file for isolation). Subsequent CLs will add more tests. Improve one client error message that occurred while debuging tests. For #393 --- mcp/streamable.go | 9 +- mcp/streamable_client_test.go | 185 ++++++++++++++++++++++++++++++++++ mcp/streamable_test.go | 71 ------------- 3 files changed, 193 insertions(+), 72 deletions(-) create mode 100644 mcp/streamable_client_test.go diff --git a/mcp/streamable.go b/mcp/streamable.go index e7777eb0..7a407538 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1234,7 +1234,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e default: resp.Body.Close() - return fmt.Errorf("unsupported content type %q", ct) + switch msg := msg.(type) { + case *jsonrpc.Request: + return fmt.Errorf("unsupported content type %q when sending %q (status: %d)", ct, msg.Method, resp.StatusCode) + case *jsonrpc.Response: + return fmt.Errorf("unsupported content type %q when sending jsonrpc response #%d (status: %d)", ct, msg.ID, resp.StatusCode) + default: + panic("unreachable") + } } return nil } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go new file mode 100644 index 00000000..ee89df5a --- /dev/null +++ b/mcp/streamable_client_test.go @@ -0,0 +1,185 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +type streamableRequestKey struct { + httpMethod string // http method + sessionID string // session ID header + jsonrpcMethod string // jsonrpc method, or "" for non-requests +} + +type header map[string]string + +type streamableResponse struct { + header header + status int // or http.StatusOK + body string // or "" + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + callback func() // if set, called after the request is handled +} + +type fakeResponses map[streamableRequestKey]*streamableResponse + +type fakeStreamableServer struct { + t *testing.T + responses fakeResponses + + callMu sync.Mutex + calls map[streamableRequestKey]int +} + +func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { + s.callMu.Lock() + defer s.callMu.Unlock() + + var unused []streamableRequestKey + for k, resp := range s.responses { + if s.calls[k] == 0 && !resp.optional { + unused = append(unused, k) + } + } + return unused +} + +func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + key := streamableRequestKey{ + httpMethod: req.Method, + sessionID: req.Header.Get(sessionIDHeader), + } + if req.Method == http.MethodPost { + body, err := io.ReadAll(req.Body) + if err != nil { + s.t.Errorf("failed to read body: %v", err) + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + s.t.Errorf("invalid body: %v", err) + http.Error(w, "invalid body", http.StatusInternalServerError) + return + } + if r, ok := msg.(*jsonrpc.Request); ok { + key.jsonrpcMethod = r.Method + } + } + + s.callMu.Lock() + if s.calls == nil { + s.calls = make(map[streamableRequestKey]int) + } + s.calls[key]++ + s.callMu.Unlock() + + resp, ok := s.responses[key] + if !ok { + s.t.Errorf("missing response for %v", key) + http.Error(w, "no response", http.StatusInternalServerError) + return + } + if resp.callback != nil { + defer resp.callback() + } + for k, v := range resp.header { + w.Header().Set(k, v) + } + status := resp.status + if status == 0 { + status = http.StatusOK + } + w.WriteHeader(status) + + if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { + s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) + } + w.Write([]byte(resp.body)) +} + +var ( + initResult = &InitializeResult{ + Capabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + } + initResp = resp(1, initResult, nil) +) + +func jsonBody(t *testing.T, msg jsonrpc2.Message) string { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatalf("encoding failed: %v", err) + } + return string(data) +} + +func TestStreamableClientTransportLifecycle(t *testing.T) { + ctx := context.Background() + + // The lifecycle test verifies various behavior of the streamable client + // initialization: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + optional: true, + wantProtocolVersion: latestProtocolVersion, + }, + {"DELETE", "123", ""}: {}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0d171d83..2963a04d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1035,77 +1035,6 @@ func mustMarshal(v any) json.RawMessage { return data } -func TestStreamableClientTransport(t *testing.T) { - // This test verifies various behavior of the streamable client transport: - // - check that it can handle application/json responses - // - check that it sends the negotiated protocol version - // - // TODO(rfindley): make this test more comprehensive, similar to - // [TestStreamableServerTransport]. - ctx := context.Background() - resp := func(id int64, result any, err error) *jsonrpc.Response { - return &jsonrpc.Response{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(result), - Error: err, - } - } - initResult := &InitializeResult{ - Capabilities: &ServerCapabilities{ - Completions: &CompletionCapabilities{}, - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: latestProtocolVersion, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - } - initResp := resp(1, initResult, nil) - - var reqN atomic.Int32 // request count - serverHandler := func(w http.ResponseWriter, r *http.Request) { - rN := reqN.Add(1) - - // TODO(rfindley): if the status code is NoContent or Accepted, we should - // probably be tolerant of when the content type is not application/json. - w.Header().Set("Content-Type", "application/json") - if rN == 1 { - data, err := jsonrpc2.EncodeMessage(initResp) - if err != nil { - t.Errorf("encoding failed: %v", err) - } - w.Header().Set("Mcp-Session-Id", "123") - w.Write(data) - } else { - if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion { - t.Errorf("bad protocol version header: got %q, want %q", v, latestProtocolVersion) - } - } - } - - httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) - defer httpServer.Close() - - transport := &StreamableClientTransport{Endpoint: httpServer.URL} - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - if err := session.Close(); err != nil { - t.Errorf("closing session: %v", err) - } - - if got, want := reqN.Load(), int32(3); got < want { - // Expect at least 3 requests: initialize, initialized, and DELETE. - // We may or may not observe the GET, depending on timing. - t.Errorf("unexpected number of requests: got %d, want at least %d", got, want) - } - - if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } -} - func TestEventID(t *testing.T) { tests := []struct { sid StreamID From 2c40bdc4d783023cd4b62649b9819db07aaa26f0 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 17:16:14 +0000 Subject: [PATCH 171/221] mcp: systematically improve streamable client errors The streamable client connection can break for a variety of reasons, asynchronously to the client's request. Decorate these failures with additional context to clarify why they occurred. Add a test for the failure message of #393. Fixes #393 --- mcp/streamable.go | 43 +++++++++-------- mcp/streamable_client_test.go | 90 +++++++++++++++++++++++++++++++++++ mcp/transport.go | 2 +- 3 files changed, 114 insertions(+), 21 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 7a407538..25efe31a 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1130,7 +1130,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // § 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE(nil, true, nil) + go c.handleSSE("hanging GET", nil, true, nil) } // fail handles an asynchronous error while reading. @@ -1224,24 +1224,27 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + switch ct := resp.Header.Get("Content-Type"); ct { case "application/json": - go c.handleJSON(resp) + go c.handleJSON(requestSummary, resp) case "text/event-stream": jsonReq, _ := msg.(*jsonrpc.Request) - go c.handleSSE(resp, false, jsonReq) + go c.handleSSE(requestSummary, resp, false, jsonReq) default: resp.Body.Close() - switch msg := msg.(type) { - case *jsonrpc.Request: - return fmt.Errorf("unsupported content type %q when sending %q (status: %d)", ct, msg.Method, resp.StatusCode) - case *jsonrpc.Response: - return fmt.Errorf("unsupported content type %q when sending jsonrpc response #%d (status: %d)", ct, msg.ID, resp.StatusCode) - default: - panic("unreachable") - } + return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct) } return nil } @@ -1265,16 +1268,16 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { } } -func (c *streamableClientConn) handleJSON(resp *http.Response) { +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - c.fail(err) + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) return } msg, err := jsonrpc.DecodeMessage(body) if err != nil { - c.fail(fmt.Errorf("failed to decode response: %v", err)) + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) return } select { @@ -1289,12 +1292,12 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { // // If forReq is set, it is the request that initiated the stream, and the // stream is complete when we receive its response. -func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { +func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { resp := initialResp var lastEventID string for { if resp != nil { - eventID, clientClosed := c.processStream(resp, forReq) + eventID, clientClosed := c.processStream(requestSummary, resp, forReq) lastEventID = eventID // If the connection was closed by the client, we're done. @@ -1312,7 +1315,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent newResp, err := c.reconnect(lastEventID) if err != nil { // All reconnection attempts failed: fail the connection. - c.fail(err) + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err)) return } resp = newResp @@ -1323,7 +1326,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent } if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() - c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode))) + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) return } // Reconnection was successful. Continue the loop with the new response. @@ -1334,7 +1337,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { +func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { @@ -1347,7 +1350,7 @@ func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrp msg, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { - c.fail(fmt.Errorf("failed to decode event: %v", err)) + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) return "", true } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index ee89df5a..fe87b21c 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,11 +6,14 @@ package mcp import ( "context" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -183,3 +186,90 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { t.Errorf("mismatch (-want, +got):\n%s", diff) } } + +func TestStreamableClientGETHandling(t *testing.T) { + ctx := context.Background() + + tests := []struct { + status int + wantErrorContaining string + }{ + {http.StatusOK, ""}, + {http.StatusMethodNotAllowed, ""}, + {http.StatusBadRequest, "hanging GET"}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) { + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: test.status, + wantProtocolVersion: latestProtocolVersion, + }, + {"POST", "123", methodListTools}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), + optional: true, + }, + {"DELETE", "123", ""}: {optional: true}, + }, + } + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // wait for all required requests to be handled, with exponential + // backoff. + start := time.Now() + delay := 1 * time.Millisecond + for range 10 { + if len(fake.missingRequests()) == 0 { + break + } + time.Sleep(delay) + delay *= 2 + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) + } + + _, err = session.ListTools(ctx, nil) + if (err != nil) != (test.wantErrorContaining != "") { + t.Errorf("After initialization, got error %v, want %v", err, test.wantErrorContaining) + } else if err != nil { + if !strings.Contains(err.Error(), test.wantErrorContaining) { + t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining) + } + } + + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + }) + } +} diff --git a/mcp/transport.go b/mcp/transport.go index fac640a6..5c7ca130 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -194,7 +194,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params err := call.Await(ctx, result) switch { case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): - return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed) + return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: // Notify the peer of cancellation. err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{ From ee5141621c346d52cb2c61a357845b4a2d33b2fd Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 17:39:05 +0000 Subject: [PATCH 172/221] mcp: be strict about returning the Mcp-Session-Id header Rather than returning the Mcp-Session-Id header for all responses, only return it from initialize, per the spec. Fixes #412 --- mcp/streamable.go | 19 +++++++++++----- mcp/streamable_test.go | 49 +++++++++++++++++++----------------------- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 25efe31a..1eef9a74 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -424,7 +424,7 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er // It is always text/event-stream, since it must carry arbitrarily many // messages. var err error - t.connection.streams[""], err = t.connection.newStream(ctx, "", false) + t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false) if err != nil { return nil, err } @@ -485,6 +485,10 @@ type stream struct { // an empty string is used for messages that don't correlate with an incoming request. id StreamID + // If isInitialize is set, the stream is in response to an initialize request, + // and therefore should include the session ID header. + isInitialize bool + // jsonResponse records whether this stream should respond with application/json // instead of text/event-stream. // @@ -513,12 +517,13 @@ type stream struct { requests map[jsonrpc.ID]struct{} } -func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) { +func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) { if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { return nil, err } return &stream{ id: id, + isInitialize: isInitialize, jsonResponse: jsonResponse, requests: make(map[jsonrpc.ID]struct{}), }, nil @@ -647,6 +652,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } requests := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) + isInitialize := false for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -656,6 +662,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } + if jreq.Method == methodInitialize { + isInitialize = true + } jreq.Extra = &RequestExtra{ TokenInfo: tokenInfo, Header: req.Header, @@ -672,7 +681,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse) + stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse) if err != nil { http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return @@ -708,7 +717,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "application/json") - if c.sessionID != "" { + if c.sessionID != "" && stream.isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) } @@ -747,7 +756,7 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") - if c.sessionID != "" { + if c.sessionID != "" && stream.isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) } if persistent { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2963a04d..3f897ba0 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -133,7 +133,7 @@ func TestStreamableTransports(t *testing.T) { defer session.Close() sid := session.ID() if sid == "" { - t.Error("empty session ID") + t.Fatalf("empty session ID") } if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { t.Fatalf("got protocol version %q, want %q", g, w) @@ -475,6 +475,8 @@ func resp(id int64, result any, err error) *jsonrpc.Response { } } +var () + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP @@ -502,7 +504,6 @@ func TestStreamableServerTransport(t *testing.T) { method: "POST", messages: []jsonrpc.Message{initializedMsg}, wantStatusCode: http.StatusAccepted, - wantSessionID: false, // TODO: should this be true? } tests := []struct { @@ -520,7 +521,6 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, - wantSessionID: true, }, }, }, @@ -535,14 +535,12 @@ func TestStreamableServerTransport(t *testing.T) { headers: http.Header{"Accept": {"text/plain", "application/*"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream - wantSessionID: false, }, { method: "POST", headers: http.Header{"Accept": {"text/event-stream"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json - wantSessionID: false, }, { method: "POST", @@ -550,7 +548,6 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, - wantSessionID: true, }, { method: "POST", @@ -558,7 +555,6 @@ func TestStreamableServerTransport(t *testing.T) { messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, - wantSessionID: true, }, }, }, @@ -598,7 +594,6 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), resp(2, &CallToolResult{}, nil), }, - wantSessionID: true, }, }, }, @@ -620,7 +615,6 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, - wantSessionID: false, }, { method: "POST", @@ -632,7 +626,6 @@ func TestStreamableServerTransport(t *testing.T) { req(1, "roots/list", &ListRootsParams{}), resp(2, &CallToolResult{}, nil), }, - wantSessionID: true, }, }, }, @@ -663,7 +656,6 @@ func TestStreamableServerTransport(t *testing.T) { resp(1, &ListRootsResult{}, nil), }, wantStatusCode: http.StatusAccepted, - wantSessionID: false, }, { method: "GET", @@ -674,7 +666,6 @@ func TestStreamableServerTransport(t *testing.T) { req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, - wantSessionID: true, }, { method: "POST", @@ -685,7 +676,6 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{}, nil), }, - wantSessionID: true, }, { method: "DELETE", @@ -724,7 +714,6 @@ func TestStreamableServerTransport(t *testing.T) { wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, - wantSessionID: true, // TODO: this is probably wrong; we don't have a valid session }, }, }, @@ -951,7 +940,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, return "", 0, nil, fmt.Errorf("creating request: %w", err) } if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(sessionIDHeader, sessionID) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") @@ -963,7 +952,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } defer resp.Body.Close() - newSessionID := resp.Header.Get("Mcp-Session-Id") + newSessionID := resp.Header.Get(sessionIDHeader) contentType := resp.Header.Get("Content-Type") var respBody []byte @@ -1079,6 +1068,15 @@ func TestEventID(t *testing.T) { } func TestStreamableStateless(t *testing.T) { + initReq := req(1, methodInitialize, &InitializeParams{}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "test", Version: "v1.0.0"}, + }, nil) // This version of sayHi expects // that request from our client). sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { @@ -1092,17 +1090,22 @@ func TestStreamableStateless(t *testing.T) { AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) requests := []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: false, // sessionless + }, { method: "POST", wantStatusCode: http.StatusOK, messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, wantBodyContaining: "greet", - wantSessionID: false, }, { method: "GET", wantStatusCode: http.StatusMethodNotAllowed, - wantSessionID: false, }, { method: "POST", @@ -1116,7 +1119,6 @@ func TestStreamableStateless(t *testing.T) { StructuredContent: json.RawMessage("null"), }, nil), }, - wantSessionID: false, }, { method: "POST", @@ -1130,7 +1132,6 @@ func TestStreamableStateless(t *testing.T) { StructuredContent: json.RawMessage("null"), }, nil), }, - wantSessionID: false, }, } @@ -1166,13 +1167,7 @@ func TestStreamableStateless(t *testing.T) { // // This can be used by tools to look up application state preserved across // subsequent requests. - for i, req := range requests { - // Now, we want a session for all (valid) requests. - if req.wantStatusCode != http.StatusMethodNotAllowed { - req.wantSessionID = true - } - requests[i] = req - } + requests[0].wantSessionID = true // now expect a session ID for initialize statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ Stateless: true, }) From 9d34aff732740c09ade9f39dc5f61466ff33d1f2 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 19:58:12 +0000 Subject: [PATCH 173/221] mcp: avoid "null" in structuredContent Avoid a typed nil in structuredContent (yet another typed nil), leading to "null" on the wire. Fixes #417 --- mcp/server.go | 28 ++++++++------- mcp/streamable_test.go | 6 ++-- mcp/testdata/conformance/server/tools.txtar | 38 ++++++++++++++++++++- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 19d902ff..508552e5 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -267,19 +267,21 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan outval = elemZero } } - outbytes, err := json.Marshal(outval) - if err != nil { - return nil, fmt.Errorf("marshaling output: %w", err) - } - res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire - - // If the Content field isn't being used, return the serialized JSON in a - // TextContent block, as the spec suggests: - // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. - if res.Content == nil { - res.Content = []Content{&TextContent{ - Text: string(outbytes), - }} + if outval != nil { + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. + if res.Content == nil { + res.Content = []Content{&TextContent{ + Text: string(outbytes), + }} + } } return res, nil } // end of handler diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3f897ba0..1e7a63ce 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1115,8 +1115,7 @@ func TestStreamableStateless(t *testing.T) { }, wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{ - Content: []Content{&TextContent{Text: "hi World"}}, - StructuredContent: json.RawMessage("null"), + Content: []Content{&TextContent{Text: "hi World"}}, }, nil), }, }, @@ -1128,8 +1127,7 @@ func TestStreamableStateless(t *testing.T) { }, wantMessages: []jsonrpc.Message{ resp(2, &CallToolResult{ - Content: []Content{&TextContent{Text: "hi foo"}}, - StructuredContent: json.RawMessage("null"), + Content: []Content{&TextContent{Text: "hi foo"}}, }, nil), }, }, diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index 870e9ea5..c39e3ec9 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -21,11 +21,15 @@ structured "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } -{"jsonrpc":"2.0", "method": "notifications/initialized"} +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } { "jsonrpc": "2.0", "id": 5, "method": "tools/call" } +{ "jsonrpc": "2.0", "id": 6, "method": "tools/call", "params": {"name": "greet", "arguments": {"name": "you"} } } +{ "jsonrpc": "2.0", "id": 1, "result": {} } +{ "jsonrpc": "2.0", "id": 7, "method": "tools/call", "params": {"name": "structured", "arguments": {"In": "input"} } } + -- server -- { "jsonrpc": "2.0", @@ -119,3 +123,35 @@ structured "message": "invalid request: missing required \"params\"" } } +{ + "jsonrpc": "2.0", + "id": 1, + "method": "ping" +} +{ + "jsonrpc": "2.0", + "id": 6, + "result": { + "content": [ + { + "type": "text", + "text": "hi you" + } + ] + } +} +{ + "jsonrpc": "2.0", + "id": 7, + "result": { + "content": [ + { + "type": "text", + "text": "{\"Out\":\"Ack input\"}" + } + ], + "structuredContent": { + "Out": "Ack input" + } + } +} From b774809086db4bcf9262d40049ff82855da5ab5a Mon Sep 17 00:00:00 2001 From: Kason Braley <59150626+KasonBraley@users.noreply.github.com> Date: Fri, 5 Sep 2025 04:58:26 -0700 Subject: [PATCH 174/221] examples: Update and fix outdated (#420) Update examples/server/rate-limiting to use the same replace directive strategy as the auth-middleware example, and fix the breakages. Also remove the unneeded type parameters in examples/client/middleware. --- examples/client/middleware/main.go | 4 +-- examples/server/rate-limiting/go.mod | 11 +++++--- examples/server/rate-limiting/go.sum | 6 +++-- examples/server/rate-limiting/main.go | 36 +++++++++++++-------------- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/examples/client/middleware/main.go b/examples/client/middleware/main.go index 6ae87df0..9b6d1bb3 100644 --- a/examples/client/middleware/main.go +++ b/examples/client/middleware/main.go @@ -17,10 +17,10 @@ var nextProgressToken atomic.Int64 // from the client. func main() { c := mcp.NewClient(&mcp.Implementation{Name: "test"}, nil) - c.AddSendingMiddleware(addProgressToken[*mcp.ClientSession]) + c.AddSendingMiddleware(addProgressToken) } -func addProgressToken[S mcp.Session](h mcp.MethodHandler) mcp.MethodHandler { +func addProgressToken(h mcp.MethodHandler) mcp.MethodHandler { return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { if rp, ok := req.GetParams().(mcp.RequestParams); ok { rp.SetProgressToken(nextProgressToken.Add(1)) diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod index 5b76b16e..2b4b00dd 100644 --- a/examples/server/rate-limiting/go.mod +++ b/examples/server/rate-limiting/go.mod @@ -2,9 +2,14 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting go 1.23.0 -toolchain go1.24.4 - require ( - github.com/modelcontextprotocol/go-sdk v0.1.0 + github.com/modelcontextprotocol/go-sdk v0.3.0 golang.org/x/time v0.12.0 ) + +require ( + github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect +) + +replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/rate-limiting/go.sum b/examples/server/rate-limiting/go.sum index d73f0a54..16b70ef7 100644 --- a/examples/server/rate-limiting/go.sum +++ b/examples/server/rate-limiting/go.sum @@ -1,7 +1,9 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/modelcontextprotocol/go-sdk v0.1.0 h1:ItzbFWYNt4EHcUrScX7P8JPASn1FVYb29G773Xkl+IU= -github.com/modelcontextprotocol/go-sdk v0.1.0/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= +github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 h1:mBlBwtDebdDYr+zdop8N62a44g+Nbv7o2KjWyS1deR4= +github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/examples/server/rate-limiting/main.go b/examples/server/rate-limiting/main.go index c3265c4c..e107183f 100644 --- a/examples/server/rate-limiting/main.go +++ b/examples/server/rate-limiting/main.go @@ -18,13 +18,13 @@ import ( // GlobalRateLimiterMiddleware creates a middleware that applies a global rate limit. // Every request attempting to pass through will try to acquire a token. // If a token cannot be acquired immediately, the request will be rejected. -func GlobalRateLimiterMiddleware[S mcp.Session](limiter *rate.Limiter) mcp.Middleware[S] { - return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { +func GlobalRateLimiterMiddleware(limiter *rate.Limiter) mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { if !limiter.Allow() { return nil, errors.New("JSON RPC overloaded") } - return next(ctx, session, method, params) + return next(ctx, method, req) } } } @@ -32,40 +32,40 @@ func GlobalRateLimiterMiddleware[S mcp.Session](limiter *rate.Limiter) mcp.Middl // PerMethodRateLimiterMiddleware creates a middleware that applies rate limiting // on a per-method basis. // Methods not specified in limiters will not be rate limited by this middleware. -func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Limiter) mcp.Middleware[S] { - return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { +func PerMethodRateLimiterMiddleware(limiters map[string]*rate.Limiter) mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { if limiter, ok := limiters[method]; ok { if !limiter.Allow() { return nil, errors.New("JSON RPC overloaded") } } - return next(ctx, session, method, params) + return next(ctx, method, req) } } } // PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting // on a per-session basis for receiving requests. -func PerSessionRateLimiterMiddleware[S mcp.Session](limit rate.Limit, burst int) mcp.Middleware[S] { +func PerSessionRateLimiterMiddleware(limit rate.Limit, burst int) mcp.Middleware { // A map to store limiters, keyed by the session ID. var ( sessionLimiters = make(map[string]*rate.Limiter) mu sync.Mutex ) - return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { // It's possible that session.ID() may be empty at this point in time // for some transports (e.g., stdio) or until the MCP initialize handshake // has completed. - sessionID := session.ID() + sessionID := req.GetSession().ID() if sessionID == "" { // In this situation, you could apply a single global identifier // if session ID is empty or bypass the rate limiter. // In this example, we bypass the rate limiter. log.Printf("Warning: Session ID is empty for method %q. Skipping per-session rate limiting.", method) - return next(ctx, session, method, params) // Skip limiting if ID is unavailable + return next(ctx, method, req) // Skip limiting if ID is unavailable } mu.Lock() limiter, ok := sessionLimiters[sessionID] @@ -77,19 +77,19 @@ func PerSessionRateLimiterMiddleware[S mcp.Session](limit rate.Limit, burst int) if !limiter.Allow() { return nil, errors.New("JSON RPC overloaded") } - return next(ctx, session, method, params) + return next(ctx, method, req) } } } func main() { - server := mcp.NewServer("greeter1", "v0.0.1", nil) - server.AddReceivingMiddleware(GlobalRateLimiterMiddleware[*mcp.ServerSession](rate.NewLimiter(rate.Every(time.Second/5), 10))) - server.AddReceivingMiddleware(PerMethodRateLimiterMiddleware[*mcp.ServerSession](map[string]*rate.Limiter{ + server := mcp.NewServer(&mcp.Implementation{Name: "greeter1", Version: "v0.0.1"}, nil) + server.AddReceivingMiddleware(GlobalRateLimiterMiddleware(rate.NewLimiter(rate.Every(time.Second/5), 10))) + server.AddReceivingMiddleware(PerMethodRateLimiterMiddleware(map[string]*rate.Limiter{ "callTool": rate.NewLimiter(rate.Every(time.Second), 5), // once a second with a burst up to 5 "listTools": rate.NewLimiter(rate.Every(time.Minute), 20), // once a minute with a burst up to 20 })) - server.AddReceivingMiddleware(PerSessionRateLimiterMiddleware[*mcp.ServerSession](rate.Every(time.Second/5), 10)) + server.AddReceivingMiddleware(PerSessionRateLimiterMiddleware(rate.Every(time.Second/5), 10)) // Run Server logic. log.Println("MCP Server instance created with Middleware (but not running).") log.Println("This example demonstrates configuration, not live interaction.") From 5ccd97fafcaed9725b5e26ca84f9023a110531c0 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 5 Sep 2025 01:10:52 +0000 Subject: [PATCH 175/221] mcp: allow requests before notifications/initialized Upon a closer reading of the spec, it is OK for the server to response to client requests before it has received notifications/initialized (as long as it has received initialize). This effectively rolls back the fix from #225. Some hooks are left for enforcing strictness around server->client requests prior to initialized. These will be revisited in subsequent CLs. For #395 --- mcp/mcp_test.go | 2 +- mcp/server.go | 38 ++++++++++++++++++- mcp/shared.go | 5 ++- .../conformance/server/lifecycle.txtar | 14 +++++-- 4 files changed, 53 insertions(+), 6 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 28a8b974..aee06a31 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1415,7 +1415,7 @@ func TestElicitationCapabilityDeclaration(t *testing.T) { RequestedSchema: &jsonschema.Schema{Type: "object"}, }) if err != nil { - t.Errorf("elicitation should work when capability is declared, got error: %v", err) + t.Fatalf("elicitation should work when capability is declared, got error: %v", err) } if result.Action != "cancel" { t.Errorf("got action %q, want %q", result.Action, "cancel") diff --git a/mcp/server.go b/mcp/server.go index 508552e5..df97b58b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -844,6 +844,33 @@ func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { } } +// hasInitialized reports whether the server has received the initialized +// notification. +// +// TODO(findleyr): use this to prevent change notifications. +func (ss *ServerSession) hasInitialized() bool { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializedParams != nil +} + +// checkInitialized returns a formatted error if the server has not yet +// received the initialized notification. +func (ss *ServerSession) checkInitialized(method string) error { + if !ss.hasInitialized() { + // TODO(rfindley): enable this check. + // Right now is is flaky, because server tests don't await the initialized notification. + // Perhaps requests should simply block until they have received the initialized notification + + // if strings.HasPrefix(method, "notifications/") { + // return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized) + // } else { + // return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized) + // } + } + return nil +} + func (ss *ServerSession) ID() string { if c, ok := ss.mcpConn.(hasSessionID); ok { return c.SessionID() @@ -859,11 +886,17 @@ func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { // ListRoots lists the client roots. func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { + if err := ss.checkInitialized(methodListRoots); err != nil { + return nil, err + } return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params))) } // CreateMessage sends a sampling request to the client. func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } if params == nil { params = &CreateMessageParams{Messages: []*SamplingMessage{}} } @@ -877,6 +910,9 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // Elicit sends an elicitation request to the client asking for user input. func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) { + if err := ss.checkInitialized(methodElicit); err != nil { + return nil, err + } return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) } @@ -978,7 +1014,7 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() - initialized := ss.state.InitializedParams != nil + initialized := ss.state.InitializeParams != nil ss.mu.Unlock() // From the spec: diff --git a/mcp/shared.go b/mcp/shared.go index e3ad6ff7..29903926 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -346,9 +346,12 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P) if sessions == nil { return } - // TODO: make this timeout configurable, or call Notify asynchronously. + // TODO: make this timeout configurable, or call handleNotify asynchronously. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + + // TODO: there's a potential spec violation here, when the feature list + // changes before the session (client or server) is initialized. for _, s := range sessions { req := newRequest(s, params) if err := handleNotify(ctx, method, req); err != nil { diff --git a/mcp/testdata/conformance/server/lifecycle.txtar b/mcp/testdata/conformance/server/lifecycle.txtar index eba287e0..0a8cf34b 100644 --- a/mcp/testdata/conformance/server/lifecycle.txtar +++ b/mcp/testdata/conformance/server/lifecycle.txtar @@ -5,6 +5,7 @@ See also modelcontextprotocol/go-sdk#225. -- client -- { "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 1, @@ -21,6 +22,14 @@ See also modelcontextprotocol/go-sdk#225. { "jsonrpc": "2.0", "id": 3, "method": "tools/list" } -- server -- +{ + "jsonrpc": "2.0", + "id": 2, + "error": { + "code": 0, + "message": "method \"tools/list\" is invalid during session initialization" + } +} { "jsonrpc": "2.0", "id": 1, @@ -43,9 +52,8 @@ See also modelcontextprotocol/go-sdk#225. { "jsonrpc": "2.0", "id": 2, - "error": { - "code": 0, - "message": "method \"tools/list\" is invalid during session initialization" + "result": { + "tools": [] } } { From 40b6bd30f1fbaf520d33d0c941d6cd6b0b45c6b7 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 5 Sep 2025 14:11:11 -0400 Subject: [PATCH 176/221] all: add examples of customizing tool schemas, and clarify documentation (#421) Add example/server/toolschemas, which demonstrates various ways to customize tool schemas, and illustrates the differences between ToolHandler and ToolHandlerFor. Fixes #385 Fixes #386 For #368 --- README.md | 105 ++++++++++---------- examples/server/toolschemas/main.go | 146 ++++++++++++++++++++++++++++ internal/readme/README.src.md | 46 +++++---- internal/readme/server/server.go | 13 +-- mcp/protocol.go | 39 ++++++-- mcp/server.go | 20 ++-- mcp/tool.go | 40 +++++++- 7 files changed, 309 insertions(+), 100 deletions(-) create mode 100644 examples/server/toolschemas/main.go diff --git a/README.md b/README.md index 126aa52e..606f579a 100644 --- a/README.md +++ b/README.md @@ -19,18 +19,6 @@ software development kit (SDK) for the Model Context Protocol (MCP). > changes. We aim to tag v1.0.0 in September, 2025. See > https://github.com/modelcontextprotocol/go-sdk/issues/328 for details. -## Design - -The design doc for this SDK is at [design.md](./design/design.md), which was -initially reviewed at -[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). - -Further design discussion should occur in -[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete -proposals) or -[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. - ## Package documentation The SDK consists of two importable packages: @@ -42,12 +30,53 @@ The SDK consists of two importable packages: - The [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing their own transports. - +- The + [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) + package provides some primitives for supporting oauth. + +## Getting started + +To get started creating an MCP server, create an `mcp.Server` instance, add +features to it, and then run it over an `mcp.Transport`. For example, this +server adds a single simple tool, and then connects it to clients over +stdin/stdout: + +```go +package main + +import ( + "context" + "log" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) -## Example +type Input struct { + Name string `json:"name" jsonschema:"the name of the person to greet"` +} + +type Output struct { + Greeting string `json:"greeting" jsonschema:"the greeting to tell to the user"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, input Input) (*mcp.CallToolResult, Output, error) { + return nil, Output{Greeting: "Hi " + input.Name}, nil +} + +func main() { + // Create a server with a single tool. + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + // Run the server over stdin/stdout, until the client disconnects + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Fatal(err) + } +} +``` -In this example, an MCP client communicates with an MCP server running in a -sidecar process: +To communicate with that server, we can similarly create an `mcp.Client` and +connect it to the corresponding server, by running the server command and +communicating over its stdin/stdout: ```go package main @@ -92,42 +121,20 @@ func main() { } ``` -Here's an example of the corresponding server component, which communicates -with its client over stdin/stdout: - -```go -package main - -import ( - "context" - "log" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type HiParams struct { - Name string `json:"name" jsonschema:"the name of the person to greet"` -} - -func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiParams) (*mcp.CallToolResult, any, error) { - return &mcp.CallToolResult{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, - }, nil, nil -} +The [`examples/`](/examples/) directory contains more example clients and +servers. -func main() { - // Create a server with a single tool. - server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) +## Design - mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) - // Run the server over stdin/stdout, until the client disconnects - if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { - log.Fatal(err) - } -} -``` +The design doc for this SDK is at [design.md](./design/design.md), which was +initially reviewed at +[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). -The [`examples/`](/examples/) directory contains more example clients and servers. +Further design discussion should occur in +[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete +proposals) or +[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for +open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. ## Acknowledgements diff --git a/examples/server/toolschemas/main.go b/examples/server/toolschemas/main.go new file mode 100644 index 00000000..36ee0495 --- /dev/null +++ b/examples/server/toolschemas/main.go @@ -0,0 +1,146 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The toolschemas example demonstrates how to create tools using both the +// low-level [ToolHandler] and high level [ToolHandlerFor], as well as how to +// customize schemas in both cases. +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Input is the input into all the tools handlers below. +type Input struct { + Name string `json:"name" jsonschema:"the person to greet"` +} + +// Output is the structured output of the tool. +// +// Not every tool needs to have structured output. +type Output struct { + Greeting string `json:"greeting" jsonschema:"the greeting to send to the user"` +} + +// simpleGreeting is an [mcp.ToolHandlerFor] that only cares about input and output. +func simpleGreeting(_ context.Context, _ *mcp.CallToolRequest, input Input) (*mcp.CallToolResult, Output, error) { + return nil, Output{"Hi " + input.Name}, nil +} + +// manualGreeter handles the parsing and validation of input and output manually. +// +// Therfore, it needs to close over its resolved schemas, to use them in +// validation. +type manualGreeter struct { + inputSchema *jsonschema.Resolved + outputSchema *jsonschema.Resolved +} + +func (t *manualGreeter) greet(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // errf produces a 'tool error', embedding the error in a CallToolResult. + errf := func(format string, args ...any) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf(format, args...)}}, + IsError: true, + } + } + // Handle the parsing and validation of input and output. + // + // Note that errors here are treated as tool errors, not protocol errors. + var input Input + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return errf("failed to unmarshal arguments: %v", err), nil + } + if err := t.inputSchema.Validate(input); err != nil { + return errf("invalid input: %v", err), nil + } + output := Output{Greeting: "Hi " + input.Name} + if err := t.outputSchema.Validate(output); err != nil { + return errf("tool produced invalid output: %v", err), nil + } + outputJSON, err := json.Marshal(output) + if err != nil { + return errf("output failed to marshal: %v", err), nil + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(outputJSON)}}, + StructuredContent: output, + }, nil +} + +func main() { + server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) + + // Add the 'greeting' tool in a few different ways. + + // First, we can just use [mcp.AddTool], and get the out-of-the-box handling + // it provides: + mcp.AddTool(server, &mcp.Tool{Name: "simple greeting"}, simpleGreeting) + + // Next, we can create our schemas entirely manually, and add them using + // [mcp.Server.AddTool]. Since we're working manually, we can add some + // constraints on the length of the name. + // + // We don't need to do all this work: below, we use jsonschema.For to start + // from the default schema. + var ( + manual manualGreeter + err error + ) + inputSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MaxLength: jsonschema.Ptr(10)}, + }, + } + manual.inputSchema, err = inputSchema.Resolve(nil) + if err != nil { + log.Fatal(err) + } + outputSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "greeting": {Type: "string"}, + }, + } + manual.outputSchema, err = outputSchema.Resolve(nil) + if err != nil { + log.Fatal(err) + } + server.AddTool(&mcp.Tool{ + Name: "manual greeting", + InputSchema: inputSchema, + OutputSchema: outputSchema, + }, manual.greet) + + // Finally, note that we can also use custom schemas with a ToolHandlerFor. + // We can do this in two ways: by using one of the schema values constructed + // above, or by using jsonschema.For and adjusting the resulting schema. + mcp.AddTool(server, &mcp.Tool{ + Name: "customized greeting 1", + InputSchema: inputSchema, + // OutputSchema will still be derived from Output. + }, simpleGreeting) + + customSchema, err := jsonschema.For[Input](nil) + if err != nil { + log.Fatal(err) + } + customSchema.Properties["name"].MaxLength = jsonschema.Ptr(10) + mcp.AddTool(server, &mcp.Tool{ + Name: "customized greeting 2", + InputSchema: customSchema, + }, simpleGreeting) + + // Now run the server. + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Printf("Server failed: %v", err) + } +} diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 40e0aa19..d128c935 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -18,18 +18,6 @@ software development kit (SDK) for the Model Context Protocol (MCP). > changes. We aim to tag v1.0.0 in September, 2025. See > https://github.com/modelcontextprotocol/go-sdk/issues/328 for details. -## Design - -The design doc for this SDK is at [design.md](./design/design.md), which was -initially reviewed at -[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). - -Further design discussion should occur in -[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete -proposals) or -[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. - ## Package documentation The SDK consists of two importable packages: @@ -41,21 +29,39 @@ The SDK consists of two importable packages: - The [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing their own transports. - +- The + [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) + package provides some primitives for supporting oauth. -## Example +## Getting started -In this example, an MCP client communicates with an MCP server running in a -sidecar process: +To get started creating an MCP server, create an `mcp.Server` instance, add +features to it, and then run it over an `mcp.Transport`. For example, this +server adds a single simple tool, and then connects it to clients over +stdin/stdout: + +%include server/server.go - + +To communicate with that server, we can similarly create an `mcp.Client` and +connect it to the corresponding server, by running the server command and +communicating over its stdin/stdout: %include client/client.go - -Here's an example of the corresponding server component, which communicates -with its client over stdin/stdout: +The [`examples/`](/examples/) directory contains more example clients and +servers. -%include server/server.go - +## Design -The [`examples/`](/examples/) directory contains more example clients and servers. +The design doc for this SDK is at [design.md](./design/design.md), which was +initially reviewed at +[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). + +Further design discussion should occur in +[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete +proposals) or +[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for +open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. ## Acknowledgements diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index aff5fcd0..e9996027 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -12,20 +12,21 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -type HiParams struct { +type Input struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiParams) (*mcp.CallToolResult, any, error) { - return &mcp.CallToolResult{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, - }, nil, nil +type Output struct { + Greeting string `json:"greeting" jsonschema:"the greeting to tell to the user"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, input Input) (*mcp.CallToolResult, Output, error) { + return nil, Output{Greeting: "Hi " + input.Name}, nil } func main() { // Create a server with a single tool. server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) - mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { diff --git a/mcp/protocol.go b/mcp/protocol.go index 27860659..a8e4817d 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -49,7 +49,9 @@ type CallToolParams struct { Arguments any `json:"arguments,omitempty"` } -// CallToolParamsRaw is passed to tool handlers on the server. +// CallToolParamsRaw is passed to tool handlers on the server. Its arguments +// are not yet unmarshaled (hence "raw"), so that the handlers can perform +// unmarshaling themselves. type CallToolParamsRaw struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -58,29 +60,48 @@ type CallToolParamsRaw struct { Arguments json.RawMessage `json:"arguments,omitempty"` } -// The server's response to a tool call. +// A CallToolResult is the server's response to a tool call. +// +// The [ToolHandler] and [ToolHandlerFor] handler functions return this result, +// though [ToolHandlerFor] populates much of it automatically as documented at +// each field. type CallToolResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` + // A list of content objects that represent the unstructured result of the tool // call. + // + // When using a [ToolHandlerFor] with structured output, if Content is unset + // it will be populated with JSON text content corresponding to the + // structured output value. Content []Content `json:"content"` - // An optional JSON object that represents the structured result of the tool - // call. + + // StructuredContent is an optional value that represents the structured + // result of the tool call. It must marshal to a JSON object. + // + // When using a [ToolHandlerFor] with structured output, you should not + // populate this field. It will be automatically populated with the typed Out + // value. StructuredContent any `json:"structuredContent,omitempty"` - // Whether the tool call ended in an error. + + // IsError reports whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). // - // Any errors that originate from the tool should be reported inside the result - // object, with isError set to true, not as an MCP protocol-level error - // response. Otherwise, the LLM would not be able to see that an error occurred - // and self-correct. + // Any errors that originate from the tool should be reported inside the + // Content field, with IsError set to true, not as an MCP protocol-level + // error response. Otherwise, the LLM would not be able to see that an error + // occurred and self-correct. // // However, any errors in finding the tool, an error indicating that the // server does not support tool calls, or any other exceptional conditions, // should be reported as an MCP error response. + // + // When using a [ToolHandlerFor], this field is automatically set when the + // tool handler returns an error, and the error string is included as text in + // the Content field. IsError bool `json:"isError,omitempty"` } diff --git a/mcp/server.go b/mcp/server.go index df97b58b..466c4459 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -321,22 +321,18 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) } // AddTool adds a tool and typed tool handler to the server. +// // If the tool's input schema is nil, it is set to the schema inferred from the -// In type parameter, using [jsonschema.For]. The In type parameter must be a +// In type parameter, using [jsonschema.For]. The In type argument must be a // map or a struct, so that its inferred JSON Schema has type "object". // -// For tools that don't return structured output, Out should be 'any'. -// Otherwise, if the tool's output schema is nil the output schema is set to -// the schema inferred from Out, which must be a map or a struct. -// -// The In argument to the handler will contain the unmarshaled arguments from -// CallToolRequest.Params.Arguments. Most users can ignore the [CallToolRequest] -// argument to the handler. +// If the tool's output schema is nil, and the Out type is not 'any', the +// output schema is set to the schema inferred from the Out type argument, +// which also must be a map or struct. // -// The handler's Out return value will be used to populate [CallToolResult.StructuredContent]. -// If the handler returns a non-nil error, [CallToolResult.IsError] will be set to true, -// and [CallToolResult.Content] will be set to the text of the error. -// Most users can ignore the [CallToolResult] return value. +// Unlike [Server.AddTool], AddTool does a lot automatically, and forces tools +// to conform to the MCP spec. See [ToolHandlerFor] for a detailed description +// of this automatic behavior. func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { tt, hh, err := toolForErr(t, h) if err != nil { diff --git a/mcp/tool.go b/mcp/tool.go index 0797700f..ffccbf30 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -14,14 +14,46 @@ import ( ) // A ToolHandler handles a call to tools/call. -// This is a low-level API, for use with [Server.AddTool]. -// Most users will write a [ToolHandlerFor] and install it with [AddTool]. +// +// This is a low-level API, for use with [Server.AddTool]. It does not do any +// pre- or post-processing of the request or result: the params contain raw +// arguments, no input validation is performed, and the result is returned to +// the user as-is, without any validation of the output. +// +// Most users will write a [ToolHandlerFor] and install it with the generic +// [AddTool] function. +// +// If ToolHandler returns an error, it is treated as a protocol error. By +// contrast, [ToolHandlerFor] automatically populates [CallToolResult.IsError] +// and [CallToolResult.Content] accordingly. type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) // A ToolHandlerFor handles a call to tools/call with typed arguments and results. +// // Use [AddTool] to add a ToolHandlerFor to a server. -// Most users can ignore the [CallToolRequest] argument and [CallToolResult] return value. -type ToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) +// +// Unlike [ToolHandler], [ToolHandlerFor] provides significant functionality +// out of the box, and enforces that the tool conforms to the MCP spec: +// - The In type provides a default input schema for the tool, though it may +// be overridden in [AddTool]. +// - The input value is automatically unmarshaled from req.Params.Arguments. +// - The input value is automatically validated against its input schema. +// Invalid input is rejected before getting to the handler. +// - If the Out type is not the empty interface [any], it provides the +// default output schema for the tool (which again may be overridden in +// [AddTool]). +// - The Out value is used to populate result.StructuredOutput. +// - If [CallToolResult.Content] is unset, it is populated with the JSON +// content of the output. +// - An error result is treated as a tool error, rather than a protocol +// error, and is therefore packed into CallToolResult.Content, with +// [IsError] set. +// +// For these reasons, most users can ignore the [CallToolRequest] argument and +// [CallToolResult] return values entirely. In fact, it is permissible to +// return a nil CallToolResult, if you only care about returning a output value +// or error. The effective result will be populated as described above. +type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { From 2a00640e1e4a6e20a1e116216227800380d73dfa Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 5 Sep 2025 14:17:21 -0400 Subject: [PATCH 177/221] mcp/internal/oauthex: auth server metadata (#294) Implement the Authorization Server Metadata spec. --- internal/oauthex/auth_meta.go | 145 ++++++++++++++++++ internal/oauthex/auth_meta_test.go | 28 ++++ internal/oauthex/oauth2.go | 61 ++++++++ internal/oauthex/resource_meta.go | 28 +--- .../oauthex/testdata/google-auth-meta.json | 57 +++++++ 5 files changed, 293 insertions(+), 26 deletions(-) create mode 100644 internal/oauthex/auth_meta.go create mode 100644 internal/oauthex/auth_meta_test.go create mode 100644 internal/oauthex/testdata/google-auth-meta.json diff --git a/internal/oauthex/auth_meta.go b/internal/oauthex/auth_meta.go new file mode 100644 index 00000000..1d6deb90 --- /dev/null +++ b/internal/oauthex/auth_meta.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" +) + +// AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, +// as defined in [RFC 8414]. +// +// Not supported: +// - signed metadata +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414) +type AuthServerMeta struct { + // GENERATED BY GEMINI 2.5. + + // Issuer is the REQUIRED URL identifying the authorization server. + Issuer string `json:"issuer"` + + // AuthorizationEndpoint is the REQUIRED URL of the server's OAuth 2.0 authorization endpoint. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the REQUIRED URL of the server's OAuth 2.0 token endpoint. + TokenEndpoint string `json:"token_endpoint"` + + // JWKSURI is the REQUIRED URL of the server's JSON Web Key Set [JWK] document. + JWKSURI string `json:"jwks_uri"` + + // RegistrationEndpoint is the RECOMMENDED URL of the server's OAuth 2.0 Dynamic Client Registration endpoint. + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + // ScopesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "scope" values that this server supports. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ResponseTypesSupported is a REQUIRED JSON array of strings containing a list of the OAuth 2.0 + // "response_type" values that this server supports. + ResponseTypesSupported []string `json:"response_types_supported"` + + // ResponseModesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "response_mode" values that this server supports. + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + // GrantTypesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // grant type values that this server supports. + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + + // TokenEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // client authentication methods supported by this token endpoint. + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + + // TokenEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings containing + // a list of the JWS signing algorithms ("alg" values) supported by the token endpoint for + // the signature on the JWT used to authenticate the client. + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + // ServiceDocumentation is a RECOMMENDED URL of a page containing human-readable documentation + // for the service. + ServiceDocumentation string `json:"service_documentation,omitempty"` + + // UILocalesSupported is a RECOMMENDED JSON array of strings representing supported + // BCP47 [RFC5646] language tag values for display in the user interface. + UILocalesSupported []string `json:"ui_locales_supported,omitempty"` + + // OpPolicyURI is a RECOMMENDED URL that the server provides to the person registering + // the client to read about the server's operator policies. + OpPolicyURI string `json:"op_policy_uri,omitempty"` + + // OpTOSURI is a RECOMMENDED URL that the server provides to the person registering the + // client to read about the server's terms of service. + OpTOSURI string `json:"op_tos_uri,omitempty"` + + // RevocationEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 revocation endpoint. + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + // RevocationEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this revocation endpoint. + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + // RevocationEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the revocation + // endpoint for the signature on the JWT used to authenticate the client. + RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + + // IntrospectionEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 introspection endpoint. + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + // IntrospectionEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this introspection endpoint. + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + // IntrospectionEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the introspection + // endpoint for the signature on the JWT used to authenticate the client. + IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + + // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // PKCE code challenge methods supported by this authorization server. + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` +} + +var wellKnownPaths = []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", +} + +// GetAuthServerMeta issues a GET request to retrieve authorization server metadata +// from an OAuth authorization server with the given issuerURL. +// +// It follows [RFC 8414]: +// - The well-known paths specified there are inserted into the URL's path, one at time. +// The first to succeed is used. +// - The Issuer field is checked against issuerURL. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414 +func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { + var errs []error + for _, p := range wellKnownPaths { + u, err := prependToPath(issuerURL, p) + if err != nil { + // issuerURL is bad; no point in continuing. + return nil, err + } + asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) + if err == nil { + if asm.Issuer != issuerURL { // section 3.3 + // Security violation; don't keep trying. + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + } + return asm, nil + } + errs = append(errs, err) + } + return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) +} diff --git a/internal/oauthex/auth_meta_test.go b/internal/oauthex/auth_meta_test.go new file mode 100644 index 00000000..724c78f9 --- /dev/null +++ b/internal/oauthex/auth_meta_test.go @@ -0,0 +1,28 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package oauthex + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestAuthMetaParse(t *testing.T) { + // Verify that we parse Google's auth server metadata. + data, err := os.ReadFile(filepath.FromSlash("testdata/google-auth-meta.json")) + if err != nil { + t.Fatal(err) + } + var a AuthServerMeta + if err := json.Unmarshal(data, &a); err != nil { + t.Fatal(err) + } + // Spot check. + if g, w := a.Issuer, "https://accounts.google.com"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/internal/oauthex/oauth2.go b/internal/oauthex/oauth2.go index d1166fe1..cf06ab39 100644 --- a/internal/oauthex/oauth2.go +++ b/internal/oauthex/oauth2.go @@ -4,3 +4,64 @@ // Package oauthex implements extensions to OAuth2. package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +// prependToPath prepends pre to the path of urlStr. +// When pre is the well-known path, this is the algorithm specified in both RFC 9728 +// section 3.1 and RFC 8414 section 3.1. +func prependToPath(urlStr, pre string) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", err + } + p := "/" + strings.Trim(pre, "/") + if u.Path != "" { + p += "/" + } + + u.Path = p + strings.TrimLeft(u.Path, "/") + return u.String(), nil +} + +// getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both +// RFC 9728 and RFC 8414. +// It will not read more than limit bytes from the body. +func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64) (*T, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if c == nil { + c = http.DefaultClient + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Specs require a 200. + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status %s", res.Status) + } + // Specs require application/json. + if ct := res.Header.Get("Content-Type"); ct != "application/json" { + return nil, fmt.Errorf("bad content type %q", ct) + } + + var t T + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + if err := dec.Decode(&t); err != nil { + return nil, err + } + return &t, nil +} diff --git a/internal/oauthex/resource_meta.go b/internal/oauthex/resource_meta.go index eb981d2d..3818d860 100644 --- a/internal/oauthex/resource_meta.go +++ b/internal/oauthex/resource_meta.go @@ -9,7 +9,6 @@ package oauthex import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -164,38 +163,15 @@ func getPRM(ctx context.Context, url string, c *http.Client, wantResource string if !strings.HasPrefix(strings.ToUpper(url), "HTTPS://") { return nil, fmt.Errorf("resource URL %q does not use HTTPS", url) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, url, 1<<20) if err != nil { return nil, err } - if c == nil { - c = http.DefaultClient - } - res, err := c.Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - - // Spec §3.2 requires a 200. - if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bad status %s", res.Status) - } - // Spec §3.2 requires application/json. - if ct := res.Header.Get("Content-Type"); ct != "application/json" { - return nil, fmt.Errorf("bad content type %q", ct) - } - - var prm ProtectedResourceMetadata - dec := json.NewDecoder(res.Body) - if err := dec.Decode(&prm); err != nil { - return nil, err - } // Validate the Resource field to thwart impersonation attacks (section 3.3). if prm.Resource != wantResource { return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) } - return &prm, nil + return prm, nil } // challenge represents a single authentication challenge from a WWW-Authenticate header. diff --git a/internal/oauthex/testdata/google-auth-meta.json b/internal/oauthex/testdata/google-auth-meta.json new file mode 100644 index 00000000..258b7534 --- /dev/null +++ b/internal/oauthex/testdata/google-auth-meta.json @@ -0,0 +1,57 @@ +{ + "issuer": "https://accounts.google.com", + "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", + "device_authorization_endpoint": "https://oauth2.googleapis.com/device/code", + "token_endpoint": "https://oauth2.googleapis.com/token", + "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo", + "revocation_endpoint": "https://oauth2.googleapis.com/revoke", + "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs", + "response_types_supported": [ + "code", + "token", + "id_token", + "code token", + "code id_token", + "token id_token", + "code token id_token", + "none" + ], + "subject_types_supported": [ + "public" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "scopes_supported": [ + "openid", + "email", + "profile" + ], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic" + ], + "claims_supported": [ + "aud", + "email", + "email_verified", + "exp", + "family_name", + "given_name", + "iat", + "iss", + "name", + "picture", + "sub" + ], + "code_challenge_methods_supported": [ + "plain", + "S256" + ], + "grant_types_supported": [ + "authorization_code", + "refresh_token", + "urn:ietf:params:oauth:grant-type:device_code", + "urn:ietf:params:oauth:grant-type:jwt-bearer" + ] +} From 46e767e646ad0c197bab8d7240aa7e1e03e87f4e Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 5 Sep 2025 16:30:59 -0400 Subject: [PATCH 178/221] Update README for v0.4.0 (#422) As we prepare to release v0.4.0, update the README to reference the new version. --- README.md | 4 ++-- internal/readme/README.src.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 606f579a..04885251 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# MCP Go SDK v0.3.0 +# MCP Go SDK v0.4.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) @@ -7,7 +7,7 @@ This version contains breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.3.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.4.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index d128c935..1de4ebad 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,4 +1,4 @@ -# MCP Go SDK v0.3.0 +# MCP Go SDK v0.4.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) @@ -6,7 +6,7 @@ This version contains breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.3.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.4.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) From 4656930de8ce967845da3e16c6afe31b260d66a4 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 5 Sep 2025 20:48:32 +0000 Subject: [PATCH 179/221] mcp: avoid "null" when tools fail to return content For some reason, this was a regression from v0.3.1, for some of our examples. Fix it and address the existing TODO. --- mcp/server.go | 10 +++++++--- mcp/streamable_test.go | 12 ++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 466c4459..af9c2ab4 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -496,9 +496,13 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR Message: fmt.Sprintf("unknown tool %q", req.Params.Name), } } - // TODO: if handler returns nil content, it will serialize as null. - // Add a test and fix. - return st.handler(ctx, req) + res, err := st.handler(ctx, req) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } + return res, err } func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 1e7a63ce..a85fbec0 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -520,7 +520,7 @@ func TestStreamableServerTransport(t *testing.T) { method: "POST", messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{}, nil)}, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, }, }, }, @@ -547,14 +547,14 @@ func TestStreamableServerTransport(t *testing.T) { headers: http.Header{"Accept": {"text/plain", "*/*"}}, messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, }, { method: "POST", headers: http.Header{"Accept": {"text/*, application/*"}}, messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, }, }, }, @@ -592,7 +592,7 @@ func TestStreamableServerTransport(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), - resp(2, &CallToolResult{}, nil), + resp(2, &CallToolResult{Content: []Content{}}, nil), }, }, }, @@ -624,7 +624,7 @@ func TestStreamableServerTransport(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{ req(1, "roots/list", &ListRootsParams{}), - resp(2, &CallToolResult{}, nil), + resp(2, &CallToolResult{Content: []Content{}}, nil), }, }, }, @@ -674,7 +674,7 @@ func TestStreamableServerTransport(t *testing.T) { }, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{ - resp(2, &CallToolResult{}, nil), + resp(2, &CallToolResult{Content: []Content{}}, nil), }, }, { From e06cc69dc42064c823e145c827a226da712922ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Mon, 8 Sep 2025 14:11:31 +0800 Subject: [PATCH 180/221] fix(docs): correct number in README (#427) --- README.md | 2 +- internal/readme/README.src.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 04885251..d8bad49e 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ software development kit (SDK) for the Model Context Protocol (MCP). ## Package documentation -The SDK consists of two importable packages: +The SDK consists of three importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 1de4ebad..5ac7cb8d 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -20,7 +20,7 @@ software development kit (SDK) for the Model Context Protocol (MCP). ## Package documentation -The SDK consists of two importable packages: +The SDK consists of three importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) From bae3558d49b590d56e21c4987ef8154afc1fc755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20RUFFINONI?= Date: Mon, 8 Sep 2025 19:28:32 +0200 Subject: [PATCH 181/221] mcp: allow configurable terminate duration for CommandTransport (#363) Make the process termination timeout configurable instead of using a hardcoded 5-second delay. This allows applications to customize the termination behavior based on their specific needs. Fixes #322. --- mcp/cmd.go | 23 +++++++++++++---- mcp/cmd_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 5 deletions(-) diff --git a/mcp/cmd.go b/mcp/cmd.go index 5ec8c9e7..01195d22 100644 --- a/mcp/cmd.go +++ b/mcp/cmd.go @@ -13,10 +13,18 @@ import ( "time" ) +const ( + defaultTerminateDuration = 5 * time.Second +) + // A CommandTransport is a [Transport] that runs a command and communicates // with it over stdin/stdout, using newline-delimited JSON. type CommandTransport struct { Command *exec.Cmd + // TerminateDuration controls how long Close waits after closing stdin + // for the process to exit before sending SIGTERM. + // If zero or negative, the default of 5s is used. + TerminateDuration time.Duration } // NewCommandTransport returns a [CommandTransport] that runs the given command @@ -46,15 +54,20 @@ func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { if err := t.Command.Start(); err != nil { return nil, err } - return newIOConn(&pipeRWC{t.Command, stdout, stdin}), nil + td := t.TerminateDuration + if td <= 0 { + td = defaultTerminateDuration + } + return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil } // A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over // stdin/stdout pipes. type pipeRWC struct { - cmd *exec.Cmd - stdout io.ReadCloser - stdin io.WriteCloser + cmd *exec.Cmd + stdout io.ReadCloser + stdin io.WriteCloser + terminateDuration time.Duration } func (s *pipeRWC) Read(p []byte) (n int, err error) { @@ -85,7 +98,7 @@ func (s *pipeRWC) Close() error { select { case err := <-resChan: return err, true - case <-time.After(5 * time.Second): + case <-time.After(s.terminateDuration): } return nil, false } diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 98354a93..146cbe1f 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -251,6 +251,74 @@ func createServerCommand(t *testing.T, serverName string) *exec.Cmd { return cmd } +func TestCommandTransportTerminateDuration(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } + requireExec(t) + + tests := []struct { + name string + duration time.Duration + wantMaxDuration time.Duration + }{ + { + name: "default duration (zero)", + duration: 0, + wantMaxDuration: 6 * time.Second, // default 5s + buffer + }, + { + name: "below minimum duration", + duration: 500 * time.Millisecond, + wantMaxDuration: 6 * time.Second, // should use default 5s + buffer + }, + { + name: "custom valid duration", + duration: 2 * time.Second, + wantMaxDuration: 3 * time.Second, // custom 2s + buffer + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Use a command that won't exit when stdin is closed + cmd := exec.Command("sleep", "20") + transport := &mcp.CommandTransport{ + Command: cmd, + TerminateDuration: tt.duration, + } + + conn, err := transport.Connect(ctx) + if err != nil { + t.Fatal(err) + } + + start := time.Now() + err = conn.Close() + elapsed := time.Since(start) + + if err != nil { + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("Close() failed with unexpected error: %v", err) + } + } + + if elapsed > tt.wantMaxDuration { + t.Errorf("Close() took %v, expected at most %v", elapsed, tt.wantMaxDuration) + } + + // Ensure the process was actually terminated + if cmd.Process != nil { + cmd.Process.Kill() + } + }) + } +} + func requireExec(t *testing.T) { t.Helper() From 4f197bc3b401a5c34e3fc25242b79ff893a9ee47 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Mon, 8 Sep 2025 13:36:50 -0400 Subject: [PATCH 182/221] internal/oauthex: fix license headers (#430) The files in the oauthex package used the default Go license header, rather than the MCP SDK license header (an accident of tooling). We should fix this while jba@google.com is the only author of this code. A subsequent PR will add a lint check to prevent this in the future. --- internal/oauthex/auth_meta.go | 4 ++-- internal/oauthex/auth_meta_test.go | 4 ++-- internal/oauthex/oauth2.go | 4 ++-- internal/oauthex/oauth2_test.go | 4 ++-- internal/oauthex/resource_meta.go | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/oauthex/auth_meta.go b/internal/oauthex/auth_meta.go index 1d6deb90..1f075f8a 100644 --- a/internal/oauthex/auth_meta.go +++ b/internal/oauthex/auth_meta.go @@ -1,5 +1,5 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. // This file implements Authorization Server Metadata. diff --git a/internal/oauthex/auth_meta_test.go b/internal/oauthex/auth_meta_test.go index 724c78f9..b83402f2 100644 --- a/internal/oauthex/auth_meta_test.go +++ b/internal/oauthex/auth_meta_test.go @@ -1,5 +1,5 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. package oauthex diff --git a/internal/oauthex/oauth2.go b/internal/oauthex/oauth2.go index cf06ab39..de164499 100644 --- a/internal/oauthex/oauth2.go +++ b/internal/oauthex/oauth2.go @@ -1,5 +1,5 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. // Package oauthex implements extensions to OAuth2. diff --git a/internal/oauthex/oauth2_test.go b/internal/oauthex/oauth2_test.go index 92017f81..9c3da156 100644 --- a/internal/oauthex/oauth2_test.go +++ b/internal/oauthex/oauth2_test.go @@ -1,5 +1,5 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. package oauthex diff --git a/internal/oauthex/resource_meta.go b/internal/oauthex/resource_meta.go index 3818d860..71d52cde 100644 --- a/internal/oauthex/resource_meta.go +++ b/internal/oauthex/resource_meta.go @@ -1,5 +1,5 @@ -// Copyright 2025 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. // This file implements Protected Resource Metadata. From a082e0764db5268c3a443cc69f4a5f0bff7a1dcc Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 9 Sep 2025 10:34:38 -0400 Subject: [PATCH 183/221] go.mod: update to jsonschema v0.2.1 --- examples/server/auth-middleware/go.mod | 2 +- examples/server/auth-middleware/go.sum | 4 ++-- examples/server/rate-limiting/go.mod | 2 +- examples/server/rate-limiting/go.sum | 4 ++-- go.mod | 2 +- go.sum | 6 ++---- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod index 402c0aae..f1b65d2b 100644 --- a/examples/server/auth-middleware/go.mod +++ b/examples/server/auth-middleware/go.mod @@ -8,7 +8,7 @@ require ( ) require ( - github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 // indirect + github.com/google/jsonschema-go v0.2.1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect ) diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum index eea5bdb5..ada94c0c 100644 --- a/examples/server/auth-middleware/go.sum +++ b/examples/server/auth-middleware/go.sum @@ -2,8 +2,8 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 h1:mBlBwtDebdDYr+zdop8N62a44g+Nbv7o2KjWyS1deR4= -github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.1 h1:Z3iINWAUmvS4+m9cMP5lWbn6WlX8Hy4rpUS4pULVliQ= +github.com/google/jsonschema-go v0.2.1/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod index 2b4b00dd..8dd2f808 100644 --- a/examples/server/rate-limiting/go.mod +++ b/examples/server/rate-limiting/go.mod @@ -8,7 +8,7 @@ require ( ) require ( - github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 // indirect + github.com/google/jsonschema-go v0.2.1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect ) diff --git a/examples/server/rate-limiting/go.sum b/examples/server/rate-limiting/go.sum index 16b70ef7..9e6a390e 100644 --- a/examples/server/rate-limiting/go.sum +++ b/examples/server/rate-limiting/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 h1:mBlBwtDebdDYr+zdop8N62a44g+Nbv7o2KjWyS1deR4= -github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.1 h1:Z3iINWAUmvS4+m9cMP5lWbn6WlX8Hy4rpUS4pULVliQ= +github.com/google/jsonschema-go v0.2.1/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= diff --git a/go.mod b/go.mod index ebcdc591..fa4d2c3d 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 - github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 + github.com/google/jsonschema-go v0.2.1 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 9ae7018a..7ed9f147 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.0 h1:Uh19091iHC56//WOsAd1oRg6yy1P9BpSvpjOL6RcjLQ= -github.com/google/jsonschema-go v0.2.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76 h1:mBlBwtDebdDYr+zdop8N62a44g+Nbv7o2KjWyS1deR4= -github.com/google/jsonschema-go v0.2.1-0.20250825175020-748c325cec76/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.1 h1:Z3iINWAUmvS4+m9cMP5lWbn6WlX8Hy4rpUS4pULVliQ= +github.com/google/jsonschema-go v0.2.1/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= From 5bbca40032e75a02c7a83a4af3a8299d0c88dc7b Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 9 Sep 2025 14:37:03 +0000 Subject: [PATCH 184/221] mcp: fix typo in protocolVersion20241105 This version was so last year. Fixes #432 --- mcp/shared.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mcp/shared.go b/mcp/shared.go index 29903926..69b8836a 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -36,13 +36,13 @@ const ( latestProtocolVersion = protocolVersion20250618 protocolVersion20250618 = "2025-06-18" protocolVersion20250326 = "2025-03-26" - protocolVersion20251105 = "2024-11-05" + protocolVersion20241105 = "2024-11-05" ) var supportedProtocolVersions = []string{ protocolVersion20250618, protocolVersion20250326, - protocolVersion20251105, + protocolVersion20241105, } // negotiatedVersion returns the effective protocol version to use, given a From b1c75f04d895b9ca4afafbf388e82a15de31f441 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 10 Sep 2025 13:36:19 +0000 Subject: [PATCH 185/221] mcp: remove deprecated transport constructors In #272, we made transports open structs and deprecated their constructors. However, we left the constructors in place to allow go:fix directives to facilitate migration. Now let's remove the constructors before cutting the release. Fixes #305 --- mcp/cmd.go | 13 ----------- mcp/sse.go | 37 ----------------------------- mcp/streamable.go | 53 ------------------------------------------ mcp/streamable_test.go | 2 +- mcp/transport.go | 20 ---------------- 5 files changed, 1 insertion(+), 124 deletions(-) diff --git a/mcp/cmd.go b/mcp/cmd.go index 01195d22..55e5cca6 100644 --- a/mcp/cmd.go +++ b/mcp/cmd.go @@ -27,19 +27,6 @@ type CommandTransport struct { TerminateDuration time.Duration } -// NewCommandTransport returns a [CommandTransport] that runs the given command -// and communicates with it over stdin/stdout. -// -// The resulting transport takes ownership of the command, starting it during -// [CommandTransport.Connect], and stopping it when the connection is closed. -// -// Deprecated: use a CommandTransport literal. -// -//go:fix inline -func NewCommandTransport(cmd *exec.Cmd) *CommandTransport { - return &CommandTransport{Command: cmd} -} - // Connect starts the command, and connects to it over stdin/stdout. func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { stdout, err := t.Command.StdoutPipe() diff --git a/mcp/sse.go b/mcp/sse.go index b7f0d4e2..f39a0397 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -114,19 +114,6 @@ type SSEServerTransport struct { done chan struct{} // closed when the connection is closed } -// NewSSEServerTransport creates a new SSE transport for the given messages -// endpoint, and hanging GET response. -// -// Deprecated: use an SSEServerTransport literal. -// -//go:fix inline -func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport { - return &SSEServerTransport{ - Endpoint: endpoint, - Response: w, - } -} - // ServeHTTP handles POST requests to the transport endpoint. func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { if t.incoming == nil { @@ -334,30 +321,6 @@ type SSEClientTransport struct { HTTPClient *http.Client } -// SSEClientTransportOptions provides options for the [NewSSEClientTransport] -// constructor. -// -// Deprecated: use an SSEClientTransport literal. -type SSEClientTransportOptions struct { - // HTTPClient is the client to use for making HTTP requests. If nil, - // http.DefaultClient is used. - HTTPClient *http.Client -} - -// NewSSEClientTransport returns a new client transport that connects to the -// SSE server at the provided URL. -// -// Deprecated: use an SSEClientTransport literal. -// -//go:fix inline -func NewSSEClientTransport(endpoint string, opts *SSEClientTransportOptions) *SSEClientTransport { - t := &SSEClientTransport{Endpoint: endpoint} - if opts != nil { - t.HTTPClient = opts.HTTPClient - } - return t -} - // Connect connects through the client endpoint. func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { parsedURL, err := url.Parse(c.Endpoint) diff --git a/mcp/streamable.go b/mcp/streamable.go index 1eef9a74..e820a2ba 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -323,15 +323,6 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque transport.ServeHTTP(w, req) } -// StreamableServerTransportOptions configures the stramable server transport. -// -// Deprecated: use a StreamableServerTransport literal. -type StreamableServerTransportOptions struct { - // Storage for events, to enable stream resumption. - // If nil, a [MemoryEventStore] with the default maximum size will be used. - EventStore EventStore -} - // A StreamableServerTransport implements the server side of the MCP streamable // transport. // @@ -385,22 +376,6 @@ type StreamableServerTransport struct { connection *streamableServerConn } -// NewStreamableServerTransport returns a new [StreamableServerTransport] with -// the given session ID and options. -// -// Deprecated: use a StreamableServerTransport literal. -// -//go:fix inline. -func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransportOptions) *StreamableServerTransport { - t := &StreamableServerTransport{ - SessionID: sessionID, - } - if opts != nil { - t.EventStore = opts.EventStore - } - return t -} - // Connect implements the [Transport] interface. func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) { if t.connection != nil { @@ -1025,34 +1000,6 @@ const ( reconnectMaxDelay = 30 * time.Second ) -// StreamableClientTransportOptions provides options for the -// [NewStreamableClientTransport] constructor. -// -// Deprecated: use a StremableClientTransport literal. -type StreamableClientTransportOptions struct { - // HTTPClient is the client to use for making HTTP requests. If nil, - // http.DefaultClient is used. - HTTPClient *http.Client - // MaxRetries is the maximum number of times to attempt a reconnect before giving up. - // It defaults to 5. To disable retries, use a negative number. - MaxRetries int -} - -// NewStreamableClientTransport returns a new client transport that connects to -// the streamable HTTP server at the provided URL. -// -// Deprecated: use a StreamableClientTransport literal. -// -//go:fix inline -func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport { - t := &StreamableClientTransport{Endpoint: url} - if opts != nil { - t.HTTPClient = opts.HTTPClient - t.MaxRetries = opts.MaxRetries - } - return t -} - // Connect implements the [Transport] interface. // // The resulting [Connection] writes messages via POST requests to the diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a85fbec0..5c9685e6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1211,7 +1211,7 @@ func TestTokenInfo(t *testing.T) { httpServer := httptest.NewServer(handler) defer httpServer.Close() - transport := NewStreamableClientTransport(httpServer.URL, nil) + transport := &StreamableClientTransport{Endpoint: httpServer.URL} client := NewClient(testImpl, nil) session, err := client.Connect(ctx, transport, nil) if err != nil { diff --git a/mcp/transport.go b/mcp/transport.go index 5c7ca130..608247cd 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -93,16 +93,6 @@ func (*StdioTransport) Connect(context.Context) (Connection, error) { return newIOConn(rwc{os.Stdin, os.Stdout}), nil } -// NewStdioTransport constructs a transport that communicates over -// stdin/stdout. -// -// Deprecated: use a StdioTransport literal. -// -//go:fix inline -func NewStdioTransport() *StdioTransport { - return &StdioTransport{} -} - // An InMemoryTransport is a [Transport] that communicates over an in-memory // network connection, using newline-delimited JSON. type InMemoryTransport struct { @@ -215,16 +205,6 @@ type LoggingTransport struct { Writer io.Writer } -// NewLoggingTransport creates a new LoggingTransport that delegates to the -// provided transport, writing RPC logs to the provided io.Writer. -// -// Deprecated: use a LoggingTransport literal. -// -//go:fix inline -func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport { - return &LoggingTransport{Transport: delegate, Writer: w} -} - // Connect connects the underlying transport, returning a [Connection] that writes // logs to the configured destination. func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { From 728e0e3ac967a0cbf94eca9bce7a32b7ce0d287d Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 9 Sep 2025 14:36:19 +0000 Subject: [PATCH 186/221] mcp: improve error messages from Wait for streamable clients Previously, the error message received from ClientSession.Wait would only report the closeErr, which would often be nil even if the client transport was broken. Wait should return the reason the session terminated, if abnormal. I'm not sure of the exact semantics of this, but surely returning nil is less useful than returning a meaningful non-nil error. We can refine our handling of errors once we have more feedback. Also add a test for client termination on HTTP server shutdown, described in #265. This should work as long as (1) the session is stateful (with a hanging GET), or (2) the session is stateless but the client has a keepalive ping. Also: don't send DELETE if the session was terminated with 404; +test. Fixes #265 --- internal/jsonrpc2/conn.go | 27 +++++++++++-- mcp/streamable.go | 70 +++++++++++++++++++++++++-------- mcp/streamable_client_test.go | 52 +++++++++++++++++++++++- mcp/streamable_test.go | 74 +++++++++++++++++++++++++++++++++++ 4 files changed, 202 insertions(+), 21 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 49902b00..5549ee1c 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -483,10 +483,32 @@ func (c *Connection) Cancel(id ID) { // Wait blocks until the connection is fully closed, but does not close it. func (c *Connection) Wait() error { + return c.wait(true) +} + +// wait for the connection to close, and aggregates the most cause of its +// termination, if abnormal. +// +// The fromWait argument allows this logic to be shared with Close, where we +// only want to expose the closeErr. +// +// (Previously, Wait also only returned the closeErr, which was misleading if +// the connection was broken for another reason). +func (c *Connection) wait(fromWait bool) error { var err error <-c.done c.updateInFlight(func(s *inFlightState) { - err = s.closeErr + if fromWait { + if !errors.Is(s.readErr, io.EOF) { + err = s.readErr + } + if err == nil && !errors.Is(s.writeErr, io.EOF) { + err = s.writeErr + } + } + if err == nil { + err = s.closeErr + } }) return err } @@ -502,8 +524,7 @@ func (c *Connection) Close() error { // Stop handling new requests, and interrupt the reader (by closing the // connection) as soon as the active requests finish. c.updateInFlight(func(s *inFlightState) { s.connClosing = true }) - - return c.Wait() + return c.wait(false) } // readIncoming collects inbound messages from the reader and delivers them, either responding diff --git a/mcp/streamable.go b/mcp/streamable.go index e820a2ba..1e35ca5b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1066,6 +1066,17 @@ type streamableClientConn struct { sessionID string } +// errSessionMissing distinguishes if the session is known to not be present on +// the server (see [streamableClientConn.fail]). +// +// TODO(rfindley): should we expose this error value (and its corresponding +// API) to the user? +// +// The spec says that if the server returns 404, clients should reestablish +// a session. For now, we delegate that to the user, but do they need a way to +// differentiate a 'NotFound' error from other errors? +var errSessionMissing = errors.New("session not found") + var _ clientConnection = (*streamableClientConn)(nil) func (c *streamableClientConn) sessionUpdated(state clientSessionState) { @@ -1093,6 +1104,10 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // // If err is non-nil, it is terminal, and subsequent (or pending) Reads will // fail. +// +// If err wraps errSessionMissing, the failure indicates that the session is no +// longer present on the server, and no final DELETE will be performed when +// closing the connection. func (c *streamableClientConn) fail(err error) { if err != nil { c.failOnce.Do(func() { @@ -1140,9 +1155,19 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + data, err := jsonrpc.EncodeMessage(msg) if err != nil { - return err + return fmt.Errorf("%s: %v", requestSummary, err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) @@ -1155,9 +1180,21 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e resp, err := c.client.Do(req) if err != nil { + return fmt.Errorf("%s: %v", requestSummary, err) + } + + // Section 2.5.3: "The server MAY terminate the session at any time, after + // which it MUST respond to requests containing that session ID with HTTP + // 404 Not Found." + if resp.StatusCode == http.StatusNotFound { + // Fail the session immediately, rather than relying on jsonrpc2 to fail + // (and close) it, because we want the call to Close to know that this + // session is missing (and therefore not send the DELETE). + err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing) + c.fail(err) + resp.Body.Close() return err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() return fmt.Errorf("broken session: %v", resp.Status) @@ -1180,16 +1217,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } - var requestSummary string - switch msg := msg.(type) { - case *jsonrpc.Request: - requestSummary = fmt.Sprintf("sending %q", msg.Method) - case *jsonrpc.Response: - requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) - default: - panic("unreachable") - } - switch ct := resp.Header.Get("Content-Type"); ct { case "application/json": go c.handleJSON(requestSummary, resp) @@ -1280,6 +1307,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt resp.Body.Close() return } + // (see equivalent handling in [streamableClientConn.Write]). + if resp.StatusCode == http.StatusNotFound { + c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing)) + return + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) @@ -1370,13 +1402,17 @@ func (c *streamableClientConn) Close() error { c.cancel() close(c.done) - req, err := http.NewRequest(http.MethodDelete, c.url, nil) - if err != nil { - c.closeErr = err + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. } else { - c.setMCPHeaders(req) - if _, err := c.client.Do(req); err != nil { + req, err := http.NewRequest(http.MethodDelete, c.url, nil) + if err != nil { c.closeErr = err + } else { + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err + } } } }) diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index fe87b21c..001d3a64 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -29,7 +29,7 @@ type streamableRequestKey struct { type header map[string]string type streamableResponse struct { - header header + header header // response headers status int // or http.StatusOK body string // or "" optional bool // if set, request need not be sent @@ -187,6 +187,56 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { } } +func TestStreamableClientRedundantDelete(t *testing.T) { + ctx := context.Background() + + // The lifecycle test verifies various behavior of the streamable client + // initialization: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + status: http.StatusMethodNotAllowed, + optional: true, + }, + {"POST", "123", methodListTools}: { + status: http.StatusNotFound, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + _, err = session.ListTools(ctx, nil) + if err == nil { + t.Errorf("Listing tools: got nil error, want non-nil") + } + _ = session.Wait() // must not hang + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } +} + func TestStreamableClientGETHandling(t *testing.T) { ctx := context.Background() diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 5c9685e6..f731a083 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -190,6 +190,80 @@ func TestStreamableTransports(t *testing.T) { } } +func TestStreamableServerShutdown(t *testing.T) { + ctx := context.Background() + + // This test checks that closing the streamable HTTP server actually results + // in client session termination, provided one of following holds: + // 1. The server is stateful, and therefore the hanging GET fails the connection. + // 2. The server is stateless, and the client uses a KeepAlive. + tests := []struct { + name string + stateless, keepalive bool + }{ + {"stateful", false, false}, + {"stateless with keepalive", true, true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := NewServer(testImpl, nil) + // Add a tool, just so we can check things are working. + AddTool(server, &Tool{Name: "greet"}, sayHi) + + handler := NewStreamableHTTPHandler( + func(req *http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: test.stateless}) + + // When we shut down the server, we need to explicitly close ongoing + // connections. Otherwise, the hanging GET may never terminate. + httpServer := httptest.NewUnstartedServer(handler) + httpServer.Config.RegisterOnShutdown(func() { + for session := range server.Sessions() { + session.Close() + } + }) + httpServer.Start() + defer httpServer.Close() + + // Connect and run a tool. + var opts ClientOptions + if test.keepalive { + opts.KeepAlive = 50 * time.Millisecond + } + client := NewClient(testImpl, &opts) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + MaxRetries: -1, // avoid slow tests during exponential retries + }, nil) + if err != nil { + t.Fatal(err) + } + + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "foo"}, + } + // Verify that we can call a tool. + if _, err := clientSession.CallTool(ctx, params); err != nil { + t.Fatalf("CallTool() failed: %v", err) + } + + // Shut down the server. Sessions should terminate. + go func() { + if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("closing http server: %v", err) + } + }() + + // Wait may return an error (after all, the connection failed), but it + // should not hang. + t.Log("Client waiting") + _ = clientSession.Wait() + }) + } +} + // TestClientReplay verifies that the client can recover from a mid-stream // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network From 3a66cf399735f4185b494bad41c340b813580005 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:21:02 -0400 Subject: [PATCH 187/221] go.mod: update jsonschema to v0.2.2 (#446) This version of jsonschema fixes embedded structs. Fixes: #437 --- go.mod | 2 +- go.sum | 4 ++-- mcp/mcp_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fa4d2c3d..0e18643d 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 - github.com/google/jsonschema-go v0.2.1 + github.com/google/jsonschema-go v0.2.2 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 7ed9f147..169c67ed 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.1 h1:Z3iINWAUmvS4+m9cMP5lWbn6WlX8Hy4rpUS4pULVliQ= -github.com/google/jsonschema-go v0.2.1/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.2 h1:qb9KM/pATIqIPuE9gEDwPsco8HHCTlA88IGFYHDl03A= +github.com/google/jsonschema-go v0.2.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index aee06a31..4b05ce7d 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1777,3 +1777,63 @@ func TestComplete(t *testing.T) { t.Errorf("Complete() mismatch (-want +got):\n%s", diff) } } + +// TestEmbeddedStructResponse performs a tool call to verify that a struct with +// an embedded pointer generates a correct, flattened JSON schema and that its +// response is validated successfully. +func TestEmbeddedStructResponse(t *testing.T) { + type foo struct { + ID string `json:"id"` + Name string `json:"name"` + } + + // bar embeds foo + type bar struct { + *foo // Embedded - should flatten in JSON + Extra string `json:"extra"` + } + + type response struct { + Data bar `json:"data"` + } + + // testTool demonstrates an embedded struct in its response. + testTool := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, response, error) { + response := response{ + Data: bar{ + foo: &foo{ + ID: "foo", + Name: "Test Foo", + }, + Extra: "additional data", + }, + } + return nil, response, nil + } + ctx := context.Background() + clientTransport, serverTransport := NewInMemoryTransports() + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + AddTool(server, &Tool{ + Name: "test_embedded_struct", + }, testTool) + + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + t.Fatal(err) + } + defer serverSession.Close() + + client := NewClient(&Implementation{Name: "test-client"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatal(err) + } + defer clientSession.Close() + + _, err = clientSession.CallTool(ctx, &CallToolParams{ + Name: "test_embedded_struct", + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } +} From 872b437b2a098abba63abab3e953b156de02087e Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 10 Sep 2025 19:25:22 +0000 Subject: [PATCH 188/221] internal/readme: use go:generate rather than Make Simplify the readme generation to just use go:generate, no longer depending on Make or bash. Subsequent CLs will generate more documentation in a similar manner. For #442 --- .github/workflows/readme-check.yml | 7 +++---- internal/readme/Makefile | 19 ------------------- internal/readme/build.sh | 22 ---------------------- internal/readme/doc.go | 9 +++++++++ mcp/mcp.go | 2 -- 5 files changed, 12 insertions(+), 47 deletions(-) delete mode 100644 internal/readme/Makefile delete mode 100755 internal/readme/build.sh create mode 100644 internal/readme/doc.go diff --git a/.github/workflows/readme-check.yml b/.github/workflows/readme-check.yml index b3398ead..bed3ff44 100644 --- a/.github/workflows/readme-check.yml +++ b/.github/workflows/readme-check.yml @@ -19,14 +19,13 @@ jobs: uses: actions/checkout@v4 - name: Check README is up-to-date run: | - cd internal/readme - make + go generate ./internal/readme if [ -n "$(git status --porcelain)" ]; then echo "ERROR: README.md is not up-to-date!" echo "" - echo "The README.md file differs from what would be generated by running 'make' in internal/readme/." + echo "The README.md file differs from what would be generated by `go generate ./internal/readme`." echo "Please update internal/readme/README.src.md instead of README.md directly," - echo "then run 'make' in the internal/readme/ directory to regenerate README.md." + echo "then run `go generate ./internal/readme` to regenerate README.md." echo "" echo "Changes:" git status --porcelain diff --git a/internal/readme/Makefile b/internal/readme/Makefile deleted file mode 100644 index 5e484184..00000000 --- a/internal/readme/Makefile +++ /dev/null @@ -1,19 +0,0 @@ -# This makefile builds ../README.md from the files in this directory. - -OUTFILE=../../README.md - -$(OUTFILE): build README.src.md - go run golang.org/x/example/internal/cmd/weave@latest README.src.md > $(OUTFILE) - -# Compile all the code used in the README. -build: $(wildcard */*.go) - go build -o /tmp/mcp-readme/ ./... - -# Preview the README on GitHub. -# $HOME/markdown must be a github repo. -# Visit https://github.com/$HOME/markdown to see the result. -preview: $(OUTFILE) - cp $(OUTFILE) $$HOME/markdown/ - (cd $$HOME/markdown/ && git commit -m . README.md && git push) - -.PHONY: build preview diff --git a/internal/readme/build.sh b/internal/readme/build.sh deleted file mode 100755 index b721d505..00000000 --- a/internal/readme/build.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/sh - -# Build README.md from the files in this directory. -# Must be invoked from the mcp directory. - -cd ../internal/readme - -outfile=../../README.md - -# Compile all the code used in the README. -go build -o /tmp/mcp-readme/ ./... -# Combine the code with the text in README.src.md. -# TODO: when at Go 1.24, use a tool directive for weave. -go run golang.org/x/example/internal/cmd/weave@latest README.src.md > $outfile - -if [[ $1 = '-preview' ]]; then - # Preview the README on GitHub. - # $HOME/markdown must be a github repo. - # Visit https://github.com/$HOME/markdown to see the result. - cp $outfile $HOME/markdown/ - (cd $HOME/markdown/ && git commit -m . README.md && git push) -fi diff --git a/internal/readme/doc.go b/internal/readme/doc.go new file mode 100644 index 00000000..34bc60c8 --- /dev/null +++ b/internal/readme/doc.go @@ -0,0 +1,9 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:generate go run golang.org/x/example/internal/cmd/weave@latest -o ../../README.md ./README.src.md + +// The readme package is used to generate README.md at the top-level of this +// repo. Regenerate the README with go generate. +package readme diff --git a/mcp/mcp.go b/mcp/mcp.go index 839f9199..88321a1e 100644 --- a/mcp/mcp.go +++ b/mcp/mcp.go @@ -2,8 +2,6 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -//go:generate ../internal/readme/build.sh - // The mcp package provides an SDK for writing model context protocol clients // and servers. // From bf3ff50a1693a201bb5d1da49c505e4adf898101 Mon Sep 17 00:00:00 2001 From: ankitm123 Date: Thu, 11 Sep 2025 10:05:43 +0200 Subject: [PATCH 189/221] mcp: remove json-rpc batching for more recent protocol versions JSON-RPC batching support is removed for protocol versions greater than or equal to protocolVersion20250618 Fixes #21 --- mcp/streamable.go | 19 ++++++++++++------- mcp/streamable_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 1e35ca5b..bfaccae4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -614,17 +614,22 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) return } - // TODO(#21): if the negotiated protocol version is 2025-06-18 or later, - // we should not allow batching here. - // - // This also requires access to the negotiated version, which would either be - // set by the MCP-Protocol-Version header, or would require peeking into the - // session. - incoming, _, err := readBatch(body) + incoming, isBatch, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } + + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + if isBatch && protocolVersion >= protocolVersion20250618 { + http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest) + return + } + requests := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index f731a083..8817784d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -647,6 +647,46 @@ func TestStreamableServerTransport(t *testing.T) { }, }, }, + { + name: "batch rejected on 2025-06-18", + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + // Explicitly set the protocol version header + headers: http.Header{"MCP-Protocol-Version": {"2025-06-18"}}, + // Two messages => batch. Expect reject. + messages: []jsonrpc.Message{ + req(101, "tools/call", &CallToolParams{Name: "tool"}), + req(102, "tools/call", &CallToolParams{Name: "tool"}), + }, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "batch", + }, + }, + }, + { + name: "batch accepted on 2025-03-26", + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + headers: http.Header{"MCP-Protocol-Version": {"2025-03-26"}}, + // Two messages => batch. Expect OK with two responses in order. + messages: []jsonrpc.Message{ + req(201, "tools/call", &CallToolParams{Name: "tool"}), + req(202, "tools/call", &CallToolParams{Name: "tool"}), + }, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ + resp(201, &CallToolResult{Content: []Content{}}, nil), + resp(202, &CallToolResult{Content: []Content{}}, nil), + }, + }, + }, + }, { name: "tool notification", tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { From b66679350e07b21fcd49a7c9200461d778a591ef Mon Sep 17 00:00:00 2001 From: Sam Thanawalla <17936816+samthanawalla@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:33:46 -0400 Subject: [PATCH 190/221] transport.go: disable stdio batching for newer protocols (#453) This PR disables batching support for stdio by storing the negotiated protocol version and comparing it to protocolVersion20250618. Fixes: #21 --- mcp/transport.go | 17 +++++++++++++++++ mcp/transport_test.go | 37 ++++++++++++++++++++++++++++++++----- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/mcp/transport.go b/mcp/transport.go index 608247cd..024863de 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -283,6 +283,8 @@ func (r rwc) Close() error { // // See [msgBatch] for more discussion of message batching. type ioConn struct { + protocolVersion string // negotiated version, set during session initialization. + writeMu sync.Mutex // guards Write, which must be concurrency safe. rwc io.ReadWriteCloser // the underlying stream @@ -360,6 +362,17 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { func (c *ioConn) SessionID() string { return "" } +func (c *ioConn) sessionUpdated(state ServerSessionState) { + protocolVersion := "" + if state.InitializeParams != nil { + protocolVersion = state.InitializeParams.ProtocolVersion + } + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + c.protocolVersion = negotiatedVersion(protocolVersion) +} + // addBatch records a msgBatch for an incoming batch payload. // It returns an error if batch is malformed, containing previously seen IDs. // @@ -458,6 +471,10 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { if err != nil { return nil, err } + if batch && t.protocolVersion >= protocolVersion20250618 { + return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion) + } + t.queue = msgs[1:] if batch { diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 18a326e8..d40ce10f 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -55,17 +55,16 @@ func TestBatchFraming(t *testing.T) { func TestIOConnRead(t *testing.T) { tests := []struct { - name string - input string - want string + name string + input string + want string + protocolVersion string }{ - { name: "valid json input", input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`, want: "", }, - { name: "newline at the end of first valid json input", input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}} @@ -77,13 +76,41 @@ func TestIOConnRead(t *testing.T) { input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`, want: "invalid trailing data at the end of stream", }, + { + name: "batching unknown protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "", + protocolVersion: "", + }, + { + name: "batching old protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "", + protocolVersion: protocolVersion20241105, + }, + { + name: "batching new protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "JSON-RPC batching is not supported in 2025-06-18 and later (request version: 2025-06-18)", + protocolVersion: protocolVersion20250618, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tr := newIOConn(rwc{ rc: io.NopCloser(strings.NewReader(tt.input)), }) + if tt.protocolVersion != "" { + tr.sessionUpdated(ServerSessionState{ + InitializeParams: &InitializeParams{ + ProtocolVersion: tt.protocolVersion, + }, + }) + } _, err := tr.Read(context.Background()) + if err == nil && tt.want != "" { + t.Errorf("ioConn.Read() got nil error but wanted %v", tt.want) + } if err != nil && err.Error() != tt.want { t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) } From eddef064df1dada305c6198d9adcd63f8da25d42 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 8 Sep 2025 18:32:15 +0000 Subject: [PATCH 191/221] all: add a test that checks for missing or mismatching file copyright In #294, we accidentally checked files with the default Go project copyright header. Avoid this in the future with a copyright test. --- copyright_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 copyright_test.go diff --git a/copyright_test.go b/copyright_test.go new file mode 100644 index 00000000..287e5948 --- /dev/null +++ b/copyright_test.go @@ -0,0 +1,57 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package gosdk + +import ( + "fmt" + "go/parser" + "go/token" + "io/fs" + "path/filepath" + "regexp" + "strings" + "testing" +) + +func TestCopyrightHeaders(t *testing.T) { + var re = regexp.MustCompile(`Copyright \d{4} The Go MCP SDK Authors. All rights reserved. +Use of this source code is governed by an MIT-style +license that can be found in the LICENSE file.`) + + err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories starting with "." or "_", and testdata directories. + if d.IsDir() && d.Name() != "." && + (strings.HasPrefix(d.Name(), ".") || + strings.HasPrefix(d.Name(), "_") || + filepath.Base(d.Name()) == "testdata") { + + return filepath.SkipDir + } + + // Skip non-go files. + if !strings.HasSuffix(path, ".go") { + return nil + } + + // Check the copyright header. + f, err := parser.ParseFile(token.NewFileSet(), path, nil, parser.ParseComments|parser.PackageClauseOnly) + if err != nil { + return fmt.Errorf("parsing %s: %v", path, err) + } + if len(f.Comments) == 0 { + t.Errorf("File %s must start with a copyright header matching %s", path, re) + } else if !re.MatchString(f.Comments[0].Text()) { + t.Errorf("Header comment for %s does not match expected copyright header.\ngot:\n%s\nwant matching:%s", path, f.Comments[0].Text(), re) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} From 9f1b8646bdf6a3c004e273c93d031baac706c352 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 11 Sep 2025 19:48:09 +0000 Subject: [PATCH 192/221] mcp: propertly validate against JSON, independent of Go values Our validation logic was avoiding double-unmarshalling as much as possible, by parsing before validation and validating the Go type. This only works if the Go type has the same structure as its JSON representation, which may not be the case in the presence of types with custom MarshalJSON or UnmarshalJSON methods (such as time.Time). But even if the Go type doesn't use any custom marshalling, validation is broken, because we can't differentiate zero values from missing values. Bite the bullet and use double-unmarshalling for both input and output schemas. Coincidentally, this fixes three bugs: - We were accepting case-insensitive JSON keys, since we parsed first, even though they should have been rejected. A number of tests were wrong. - Defaults were overriding present-yet-zero values, as noted in an incorrect test case. - When "arguments" was missing, validation wasn't performed, no defaults were applied, and unmarshalling failed even if all properties were optional. First unmarshalling to map[string]any allows us to fix all these bugs. Unfortunately, it means a 3x increase in the number of reflection operations (we need to unmarshal, apply defaults and validate, re-marshal with the defaults, and then unmarshal into the Go type). However, this is not likely to be a significant overhead, and we can always optimize in the future. Update github.com/google/jsonschema-go to pick up necessary improvements supporting this change. Additionally, fix the error codes for invalid tool parameters, to be consistent with other SDKs (Invalid Params: -32602). Fixes #447 Fixes #449 Updates #450 --- go.mod | 2 +- go.sum | 4 +- mcp/conformance_test.go | 43 ++++++- mcp/mcp_test.go | 6 +- mcp/protocol.go | 7 ++ mcp/server.go | 45 ++++--- mcp/server_test.go | 53 +++++--- mcp/sse_example_test.go | 3 +- mcp/streamable_test.go | 5 +- mcp/testdata/conformance/server/tools.txtar | 127 +++++++++++++++++++- mcp/tool.go | 64 +++++----- mcp/tool_test.go | 13 +- 12 files changed, 287 insertions(+), 85 deletions(-) diff --git a/go.mod b/go.mod index 0e18643d..d303ef0c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 - github.com/google/jsonschema-go v0.2.2 + github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 169c67ed..6903b659 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.2 h1:qb9KM/pATIqIPuE9gEDwPsco8HHCTlA88IGFYHDl03A= -github.com/google/jsonschema-go v0.2.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo= +github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index a8da4fb7..3393efcb 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -20,9 +20,11 @@ import ( "strings" "testing" "testing/synctest" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "golang.org/x/tools/txtar" @@ -97,16 +99,40 @@ func TestServerConformance(t *testing.T) { } } -type input struct { +type structuredInput struct { In string `jsonschema:"the input"` } -type output struct { +type structuredOutput struct { Out string `jsonschema:"the output"` } -func structuredTool(ctx context.Context, req *CallToolRequest, args *input) (*CallToolResult, *output, error) { - return nil, &output{"Ack " + args.In}, nil +func structuredTool(ctx context.Context, req *CallToolRequest, args *structuredInput) (*CallToolResult, *structuredOutput, error) { + return nil, &structuredOutput{"Ack " + args.In}, nil +} + +type tomorrowInput struct { + Now time.Time +} + +type tomorrowOutput struct { + Tomorrow time.Time +} + +func tomorrowTool(ctx context.Context, req *CallToolRequest, args tomorrowInput) (*CallToolResult, tomorrowOutput, error) { + return nil, tomorrowOutput{args.Now.Add(24 * time.Hour)}, nil +} + +type incInput struct { + X int `json:"x,omitempty"` +} + +type incOutput struct { + Y int `json:"y"` +} + +func incTool(_ context.Context, _ *CallToolRequest, args incInput) (*CallToolResult, incOutput, error) { + return nil, incOutput{args.X + 1}, nil } // runServerTest runs the server conformance test. @@ -124,6 +150,15 @@ func runServerTest(t *testing.T, test *conformanceTest) { }, sayHi) case "structured": AddTool(s, &Tool{Name: "structured"}, structuredTool) + case "tomorrow": + AddTool(s, &Tool{Name: "tomorrow"}, tomorrowTool) + case "inc": + inSchema, err := jsonschema.For[incInput](nil) + if err != nil { + t.Fatal(err) + } + inSchema.Properties["x"].Default = json.RawMessage(`6`) + AddTool(s, &Tool{Name: "inc", InputSchema: inSchema}, incTool) default: t.Fatalf("unknown tool %q", tn) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4b05ce7d..6191954c 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -224,7 +224,7 @@ func TestEndToEnd(t *testing.T) { // ListTools is tested in client_list_test.go. gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "user"}, + Arguments: map[string]any{"Name": "user"}, }) if err != nil { t.Fatal(err) @@ -648,7 +648,7 @@ func TestServerClosing(t *testing.T) { }() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "user"}, + Arguments: map[string]any{"Name": "user"}, }); err != nil { t.Fatalf("after connecting: %v", err) } @@ -1646,7 +1646,7 @@ var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} // If anyone asks, we can add an option that controls how pointers are treated. func TestPointerArgEquivalence(t *testing.T) { type input struct { - In string + In string `json:",omitempty"` } type output struct { Out string diff --git a/mcp/protocol.go b/mcp/protocol.go index a8e4817d..aeb9adbd 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -105,6 +105,13 @@ type CallToolResult struct { IsError bool `json:"isError,omitempty"` } +// TODO(#64): consider exposing setError (and getError), by adding an error +// field on CallToolResult. +func (r *CallToolResult) setError(err error) { + r.Content = []Content{&TextContent{Text: err.Error()}} + r.IsError = true +} + func (*CallToolResult) isResult() {} // UnmarshalJSON handles the unmarshalling of content into the Content diff --git a/mcp/server.go b/mcp/server.go index af9c2ab4..69808ac7 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -189,7 +189,6 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { func() bool { s.tools.add(st); return true }) } -// TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { tt := *t @@ -221,11 +220,23 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan } th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + var input json.RawMessage + if req.Params.Arguments != nil { + input = req.Params.Arguments + } + // Validate input and apply defaults. + var err error + input, err = applySchema(input, inputResolved) + if err != nil { + // TODO(#450): should this be considered a tool error? (and similar below) + return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err) + } + // Unmarshal and validate args. var in In - if req.Params.Arguments != nil { - if err := unmarshalSchema(req.Params.Arguments, inputResolved, &in); err != nil { - return nil, err + if input != nil { + if err := json.Unmarshal(input, &in); err != nil { + return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err) } } @@ -241,22 +252,15 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan return nil, wireErr } // For regular errors, embed them in the tool result as per MCP spec - return &CallToolResult{ - Content: []Content{&TextContent{Text: err.Error()}}, - IsError: true, - }, nil - } - - // Validate output schema, if any. - // Skip if out is nil: we've removed "null" from the output schema, so nil won't validate. - if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() { - } else if err := validateSchema(outputResolved, &out); err != nil { - return nil, fmt.Errorf("tool output: %w", err) + var errRes CallToolResult + errRes.setError(err) + return &errRes, nil } if res == nil { res = &CallToolResult{} } + // Marshal the output and put the RawMessage in the StructuredContent field. var outval any = out if elemZero != nil { @@ -272,7 +276,16 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan if err != nil { return nil, fmt.Errorf("marshaling output: %w", err) } - res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.StructuredContent = outJSON // avoid a second marshal over the wire // If the Content field isn't being used, return the serialized JSON in a // TextContent block, as the spec suggests: diff --git a/mcp/server_test.go b/mcp/server_test.go index e46be379..4456495f 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "log" "slices" + "strings" "testing" "time" @@ -491,7 +492,7 @@ func TestAddTool(t *testing.T) { type schema = jsonschema.Schema -func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErr bool) { +func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErrContaining string) { t.Helper() th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { return nil, out, nil @@ -513,34 +514,48 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out } _, err = goth(context.Background(), ctr) - if gotErr := err != nil; gotErr != wantErr { - t.Errorf("got error: %t, want error: %t", gotErr, wantErr) + if wantErrContaining != "" { + if err == nil { + t.Errorf("got nil error, want error containing %q", wantErrContaining) + } else { + if !strings.Contains(err.Error(), wantErrContaining) { + t.Errorf("got error %q, want containing %q", err, wantErrContaining) + } + } + } else if err != nil { + t.Errorf("got error %v, want no error", err) } } func TestToolForSchemas(t *testing.T) { - // Validate that ToolFor handles schemas properly. + // Validate that toolForErr handles schemas properly. + type in struct { + P int `json:"p,omitempty"` + } + type out struct { + B bool `json:"b,omitempty"` + } + + var ( + falseSchema = &schema{Not: &schema{}} + inSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "integer"}}} + inSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "string"}}} + outSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "boolean"}}} + outSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "integer"}}} + ) // Infer both schemas. - testToolForSchema[int](t, &Tool{}, "3", true, - &schema{Type: "integer"}, &schema{Type: "boolean"}, false) + testToolForSchema[in](t, &Tool{}, `{"p":3}`, out{true}, inSchema, outSchema, "") // Validate the input schema: expect an error if it's wrong. // We can't test that the output schema is validated, because it's typed. - testToolForSchema[int](t, &Tool{}, `"x"`, true, - &schema{Type: "integer"}, &schema{Type: "boolean"}, true) - + testToolForSchema[in](t, &Tool{}, `{"p":"x"}`, out{true}, inSchema, outSchema, `want "integer"`) // Ignore type any for output. - testToolForSchema[int, any](t, &Tool{}, "3", 0, - &schema{Type: "integer"}, nil, false) + testToolForSchema[in, any](t, &Tool{}, `{"p":3}`, 0, inSchema, nil, "") // Input is still validated. - testToolForSchema[int, any](t, &Tool{}, `"x"`, 0, - &schema{Type: "integer"}, nil, true) - + testToolForSchema[in, any](t, &Tool{}, `{"p":"x"}`, 0, inSchema, nil, `want "integer"`) // Tool sets input schema: that is what's used. - testToolForSchema[int, any](t, &Tool{InputSchema: &schema{Type: "string"}}, "3", 0, - &schema{Type: "string"}, nil, true) // error: 3 is not a string - + testToolForSchema[in, any](t, &Tool{InputSchema: inSchema2}, `{"p":3}`, 0, inSchema2, nil, `want "string"`) // Tool sets output schema: that is what's used, and validation happens. - testToolForSchema[string, any](t, &Tool{OutputSchema: &schema{Type: "integer"}}, "3", "x", - &schema{Type: "string"}, &schema{Type: "integer"}, true) // error: "x" is not an integer + testToolForSchema[in, any](t, &Tool{OutputSchema: outSchema2}, `{"p":3}`, out{true}, + inSchema, outSchema2, `want "integer"`) } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 7d777114..d06ea62b 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -15,7 +15,8 @@ import ( ) type AddParams struct { - X, Y int + X int `json:"x"` + Y int `json:"y"` } func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8817784d..e077308c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -144,7 +144,7 @@ func TestStreamableTransports(t *testing.T) { // The "greet" tool should just work. params := &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "foo"}, + Arguments: map[string]any{"Name": "foo"}, } got, err := session.CallTool(ctx, params) if err != nil { @@ -239,10 +239,11 @@ func TestStreamableServerShutdown(t *testing.T) { if err != nil { t.Fatal(err) } + defer clientSession.Close() params := &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "foo"}, + Arguments: map[string]any{"Name": "foo"}, } // Verify that we can call a tool. if _, err := clientSession.CallTool(ctx, params); err != nil { diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index c39e3ec9..b582dda8 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -5,10 +5,17 @@ Fixed bugs: - "_meta" should not be nil - empty resource or prompts should not be returned as 'null' - the server should not crash when params are passed to tools/call +- missing required input fields should be rejected (#449) +- output and input should be validated against their actual json, not Go + representation +- When arguments are missing, the request should succeed if all properties are + optional, and observe any default values. -- tools -- greet structured +tomorrow +inc -- client -- { @@ -26,9 +33,14 @@ structured { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } { "jsonrpc": "2.0", "id": 5, "method": "tools/call" } -{ "jsonrpc": "2.0", "id": 6, "method": "tools/call", "params": {"name": "greet", "arguments": {"name": "you"} } } +{ "jsonrpc": "2.0", "id": 6, "method": "tools/call", "params": {"name": "greet", "arguments": {"Name": "you"} } } { "jsonrpc": "2.0", "id": 1, "result": {} } { "jsonrpc": "2.0", "id": 7, "method": "tools/call", "params": {"name": "structured", "arguments": {"In": "input"} } } +{ "jsonrpc": "2.0", "id": 8, "method": "tools/call", "params": {"name": "structured", "arguments": {} } } +{ "jsonrpc": "2.0", "id": 9, "method": "tools/call", "params": {"name": "tomorrow", "arguments": { "Now": "2025-06-18T15:04:05Z" } } } +{ "jsonrpc": "2.0", "id": 10, "method": "tools/call", "params": {"name": "greet" } } +{ "jsonrpc": "2.0", "id": 11, "method": "tools/call", "params": {"name": "inc", "arguments": { "x": 3 } } } +{ "jsonrpc": "2.0", "id": 11, "method": "tools/call", "params": {"name": "inc" } } -- server -- { @@ -69,6 +81,31 @@ structured }, "name": "greet" }, + { + "inputSchema": { + "type": "object", + "properties": { + "x": { + "type": "integer", + "default": 6 + } + }, + "additionalProperties": false + }, + "name": "inc", + "outputSchema": { + "type": "object", + "required": [ + "y" + ], + "properties": { + "y": { + "type": "integer" + } + }, + "additionalProperties": false + } + }, { "inputSchema": { "type": "object", @@ -97,6 +134,33 @@ structured }, "additionalProperties": false } + }, + { + "inputSchema": { + "type": "object", + "required": [ + "Now" + ], + "properties": { + "Now": { + "type": "string" + } + }, + "additionalProperties": false + }, + "name": "tomorrow", + "outputSchema": { + "type": "object", + "required": [ + "Tomorrow" + ], + "properties": { + "Tomorrow": { + "type": "string" + } + }, + "additionalProperties": false + } } ] } @@ -155,3 +219,64 @@ structured } } } +{ + "jsonrpc": "2.0", + "id": 8, + "error": { + "code": -32602, + "message": "invalid params: validating \"arguments\": validating root: required: missing properties: [\"In\"]" + } +} +{ + "jsonrpc": "2.0", + "id": 9, + "result": { + "content": [ + { + "type": "text", + "text": "{\"Tomorrow\":\"2025-06-19T15:04:05Z\"}" + } + ], + "structuredContent": { + "Tomorrow": "2025-06-19T15:04:05Z" + } + } +} +{ + "jsonrpc": "2.0", + "id": 10, + "error": { + "code": -32602, + "message": "invalid params: validating \"arguments\": validating root: required: missing properties: [\"Name\"]" + } +} +{ + "jsonrpc": "2.0", + "id": 11, + "result": { + "content": [ + { + "type": "text", + "text": "{\"y\":4}" + } + ], + "structuredContent": { + "y": 4 + } + } +} +{ + "jsonrpc": "2.0", + "id": 11, + "result": { + "content": [ + { + "type": "text", + "text": "{\"y\":7}" + } + ], + "structuredContent": { + "y": 7 + } + } +} diff --git a/mcp/tool.go b/mcp/tool.go index ffccbf30..12b02b7b 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -5,7 +5,6 @@ package mcp import ( - "bytes" "context" "encoding/json" "fmt" @@ -61,41 +60,44 @@ type serverTool struct { handler ToolHandler } -// unmarshalSchema unmarshals data into v and validates the result according to -// the given resolved schema. -func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) error { +// applySchema validates whether data is valid JSON according to the provided +// schema, after applying schema defaults. +// +// Returns the JSON value augmented with defaults. +func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) { // TODO: use reflection to create the struct type to unmarshal into. // Separate validation from assignment. - // Disallow unknown fields. - // Otherwise, if the tool was built with a struct, the client could send extra - // fields and json.Unmarshal would ignore them, so the schema would never get - // a chance to declare the extra args invalid. - dec := json.NewDecoder(bytes.NewReader(data)) - dec.DisallowUnknownFields() - if err := dec.Decode(v); err != nil { - return fmt.Errorf("unmarshaling: %w", err) - } - return validateSchema(resolved, v) -} - -func validateSchema(resolved *jsonschema.Resolved, value any) error { + // Use default JSON marshalling for validation. + // + // This avoids inconsistent representation due to custom marshallers, such as + // time.Time (issue #449). + // + // Additionally, unmarshalling into a map ensures that the resulting JSON is + // at least {}, even if data is empty. For example, arguments is technically + // an optional property of callToolParams, and we still want to apply the + // defaults in this case. + // + // TODO(rfindley): in which cases can resolved be nil? if resolved != nil { - if err := resolved.ApplyDefaults(value); err != nil { - return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%v:\n%w", schemaJSON(resolved.Schema()), value, err) + v := make(map[string]any) + if len(data) > 0 { + if err := json.Unmarshal(data, &v); err != nil { + return nil, fmt.Errorf("unmarshaling arguments: %w", err) + } } - if err := resolved.Validate(value); err != nil { - return fmt.Errorf("validating\n\t%v\nagainst\n\t %s:\n %w", value, schemaJSON(resolved.Schema()), err) + if err := resolved.ApplyDefaults(&v); err != nil { + return nil, fmt.Errorf("applying schema defaults:\n%w", err) + } + if err := resolved.Validate(&v); err != nil { + return nil, err + } + // We must re-marshal with the default values applied. + var err error + data, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("marshalling with defaults: %v", err) } } - return nil -} - -// schemaJSON returns the JSON value for s as a string, or a string indicating an error. -func schemaJSON(s *jsonschema.Schema) string { - m, err := json.Marshal(s) - if err != nil { - return fmt.Sprintf("", err) - } - return string(m) + return data, nil } diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 2722a9ac..ef26e9dc 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -17,7 +17,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) -func TestUnmarshalSchema(t *testing.T) { +func TestApplySchema(t *testing.T) { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -39,20 +39,23 @@ func TestUnmarshalSchema(t *testing.T) { want any }{ {`{"x": 1}`, new(S), &S{X: 1}}, - {`{}`, new(S), &S{X: 3}}, // default applied - {`{"x": 0}`, new(S), &S{X: 3}}, // FAIL: should be 0. (requires double unmarshal) + {`{}`, new(S), &S{X: 3}}, // default applied + {`{"x": 0}`, new(S), &S{X: 0}}, {`{"x": 1}`, new(map[string]any), &map[string]any{"x": 1.0}}, {`{}`, new(map[string]any), &map[string]any{"x": 3.0}}, // default applied {`{"x": 0}`, new(map[string]any), &map[string]any{"x": 0.0}}, } { raw := json.RawMessage(tt.data) - if err := unmarshalSchema(raw, resolved, tt.v); err != nil { + raw, err = applySchema(raw, resolved) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(raw, &tt.v); err != nil { t.Fatal(err) } if !reflect.DeepEqual(tt.v, tt.want) { t.Errorf("got %#v, want %#v", tt.v, tt.want) } - } } From 159933cfa36628d1b2b6704152776356e51cb522 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 12 Sep 2025 14:42:07 +0000 Subject: [PATCH 193/221] mcp: speed up TestCommandTransportDuration At 7s, this test was 70% of our test execution time on my computer. Speed it up by mutating the defaultTerminateDuration for the purposes of the test. Also update our test workflow to use go1.25, now that it's out. --- .github/workflows/test.yml | 2 +- mcp/cmd.go | 4 +--- mcp/cmd_export_test.go | 20 ++++++++++++++++++++ mcp/cmd_test.go | 24 ++++++++++++++++++------ 4 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 mcp/cmd_export_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d8e2d31e..9ad50c1d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,7 +39,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ["1.23", "1.24", "1.25.0-rc.3"] + go: ["1.23", "1.24", "1.25"] steps: - name: Check out code uses: actions/checkout@v4 diff --git a/mcp/cmd.go b/mcp/cmd.go index 55e5cca6..b531eaf1 100644 --- a/mcp/cmd.go +++ b/mcp/cmd.go @@ -13,9 +13,7 @@ import ( "time" ) -const ( - defaultTerminateDuration = 5 * time.Second -) +var defaultTerminateDuration = 5 * time.Second // mutable for testing // A CommandTransport is a [Transport] that runs a command and communicates // with it over stdin/stdout, using newline-delimited JSON. diff --git a/mcp/cmd_export_test.go b/mcp/cmd_export_test.go new file mode 100644 index 00000000..331e8bd0 --- /dev/null +++ b/mcp/cmd_export_test.go @@ -0,0 +1,20 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import "time" + +// This file exports some helpers for mutating internals of the command +// transport for testing. + +// SetDefaultTerminateDuration sets the default command terminate duration, +// and returns a function to reset it to the default. +func SetDefaultTerminateDuration(d time.Duration) (reset func()) { + initial := defaultTerminateDuration + defaultTerminateDuration = d + return func() { + defaultTerminateDuration = initial + } +} diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 146cbe1f..0df45708 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -257,25 +257,35 @@ func TestCommandTransportTerminateDuration(t *testing.T) { } requireExec(t) + // Unfortunately, since it does I/O, this test needs to rely on timing (we + // can't use synctest). However, we can still decreate the default + // termination duration to speed up the test. + const defaultDur = 50 * time.Millisecond + defer mcp.SetDefaultTerminateDuration(defaultDur)() + tests := []struct { name string duration time.Duration + wantMinDuration time.Duration wantMaxDuration time.Duration }{ { name: "default duration (zero)", duration: 0, - wantMaxDuration: 6 * time.Second, // default 5s + buffer + wantMinDuration: defaultDur, + wantMaxDuration: 1 * time.Second, // default + buffer }, { name: "below minimum duration", - duration: 500 * time.Millisecond, - wantMaxDuration: 6 * time.Second, // should use default 5s + buffer + duration: -500 * time.Millisecond, + wantMinDuration: defaultDur, + wantMaxDuration: 1 * time.Second, // should use default + buffer }, { name: "custom valid duration", - duration: 2 * time.Second, - wantMaxDuration: 3 * time.Second, // custom 2s + buffer + duration: 200 * time.Millisecond, + wantMinDuration: 200 * time.Millisecond, + wantMaxDuration: 1 * time.Second, // custom + buffer }, } @@ -306,7 +316,9 @@ func TestCommandTransportTerminateDuration(t *testing.T) { t.Fatalf("Close() failed with unexpected error: %v", err) } } - + if elapsed < tt.wantMinDuration { + t.Errorf("Close() took %v, expected at least %v", elapsed, tt.wantMinDuration) + } if elapsed > tt.wantMaxDuration { t.Errorf("Close() took %v, expected at most %v", elapsed, tt.wantMaxDuration) } From c2c810bf40eb390df87680736fa9efd642b399dc Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Fri, 12 Sep 2025 15:09:51 +0000 Subject: [PATCH 194/221] mcp: update jsonschema-go to v0.2.3, prepare README --- README.md | 4 ++-- go.mod | 2 +- go.sum | 2 ++ internal/readme/README.src.md | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d8bad49e..2888b740 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# MCP Go SDK v0.4.0 +# MCP Go SDK v0.5.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) @@ -7,7 +7,7 @@ This version contains breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.4.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.5.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) diff --git a/go.mod b/go.mod index d303ef0c..56910893 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 - github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 + github.com/google/jsonschema-go v0.2.3 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 6903b659..d30dd343 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo= github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 5ac7cb8d..45a57f65 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,4 +1,4 @@ -# MCP Go SDK v0.4.0 +# MCP Go SDK v0.5.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) @@ -6,7 +6,7 @@ This version contains breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.4.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.5.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) From ab4d6218c6cecebc9e008fe94a6bff3106e91995 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 16 Sep 2025 07:58:37 -0400 Subject: [PATCH 195/221] mcp: tweak ProgressNotificationParams --- mcp/protocol.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mcp/protocol.go b/mcp/protocol.go index aeb9adbd..7be8ea17 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -625,15 +625,16 @@ type ProgressNotificationParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` + // The progress token which was given in the initial request, used to associate + // this notification with the request that is proceeding. + ProgressToken any `json:"progressToken"` // An optional message describing the current progress. Message string `json:"message,omitempty"` // The progress thus far. This should increase every time progress is made, even // if the total is unknown. Progress float64 `json:"progress"` - // The progress token which was given in the initial request, used to associate - // this notification with the request that is proceeding. - ProgressToken any `json:"progressToken"` // Total number of items to process (or total progress required), if known. + // Zero means unknown. Total float64 `json:"total,omitempty"` } From ac2f1752141b9ab81c8c26da375131aa779ef52a Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Tue, 16 Sep 2025 11:13:40 -0400 Subject: [PATCH 196/221] docs: add feature and troubleshooting documentation (#463) Add a framework for feature documentation, and start populating it with our SDK documentation. This framework is as follows: - internal/docs/**.src.md is the markdown source for the docs/ directory. - The x/example/internal/cmd/weave tool is used to compile these docs to the top-level docs, supporting both linked code samples and generated tables of contents. - The readme-check workflow is updated to check these docs as well. - The structure of these docs follows the MCP spec. - Wherever possible, example code is linked from actual Go documentation examples, so that it is testable. Some minor modifications to the weave tool were made to support this framework. Additionally, partially fill out this documentation with content on base protocol and client features, as well as troubleshooting help. Along the way, a bug was encountered that our LoggingTransport was not concurrency safe. This is fixed with a mutex. Fixes https://github.com/modelcontextprotocol/go-sdk/issues/466 Fixes https://github.com/modelcontextprotocol/go-sdk/issues/409 Updates https://github.com/modelcontextprotocol/go-sdk/issues/442 --- .github/workflows/readme-check.yml | 20 +- docs/README.md | 40 +++ docs/client.md | 168 ++++++++++++ docs/protocol.md | 368 +++++++++++++++++++++++++++ docs/server.md | 38 +++ docs/troubleshooting.md | 90 +++++++ internal/docs/README.src.md | 39 +++ internal/docs/client.src.md | 55 ++++ internal/docs/doc.go | 15 ++ internal/docs/protocol.src.md | 198 ++++++++++++++ internal/docs/server.src.md | 31 +++ internal/docs/troubleshooting.src.md | 42 +++ internal/readme/client/client.go | 2 +- mcp/client.go | 17 +- mcp/client_example_test.go | 136 ++++++++++ mcp/mcp_example_test.go | 166 ++++++++++++ mcp/mcp_test.go | 8 +- mcp/root.go | 5 - mcp/streamable.go | 16 +- mcp/streamable_example_test.go | 86 +++++++ mcp/transport.go | 16 +- mcp/transport_example_test.go | 40 +++ 22 files changed, 1566 insertions(+), 30 deletions(-) create mode 100644 docs/README.md create mode 100644 docs/client.md create mode 100644 docs/protocol.md create mode 100644 docs/server.md create mode 100644 docs/troubleshooting.md create mode 100644 internal/docs/README.src.md create mode 100644 internal/docs/client.src.md create mode 100644 internal/docs/doc.go create mode 100644 internal/docs/protocol.src.md create mode 100644 internal/docs/server.src.md create mode 100644 internal/docs/troubleshooting.src.md create mode 100644 mcp/client_example_test.go create mode 100644 mcp/mcp_example_test.go delete mode 100644 mcp/root.go create mode 100644 mcp/streamable_example_test.go create mode 100644 mcp/transport_example_test.go diff --git a/.github/workflows/readme-check.yml b/.github/workflows/readme-check.yml index bed3ff44..8709be11 100644 --- a/.github/workflows/readme-check.yml +++ b/.github/workflows/readme-check.yml @@ -1,11 +1,13 @@ name: README Check on: - workflow_dispatch: + workflow_dispatch: pull_request: paths: - 'internal/readme/**' - 'README.md' - + - 'internal/docs/**' + - 'docs/**' + permissions: contents: read @@ -17,15 +19,15 @@ jobs: uses: actions/setup-go@v5 - name: Check out code uses: actions/checkout@v4 - - name: Check README is up-to-date + - name: Check docs is up-to-date run: | - go generate ./internal/readme + go generate ./... if [ -n "$(git status --porcelain)" ]; then - echo "ERROR: README.md is not up-to-date!" + echo "ERROR: docs are not up-to-date!" echo "" - echo "The README.md file differs from what would be generated by `go generate ./internal/readme`." - echo "Please update internal/readme/README.src.md instead of README.md directly," - echo "then run `go generate ./internal/readme` to regenerate README.md." + echo "The docs differ from what would be generated by `go generate ./...`." + echo "Please update internal/**/*.src.md instead of directly editing README.md or docs/ files," + echo "then run `go generate ./...` to regenerate docs." echo "" echo "Changes:" git status --porcelain @@ -34,4 +36,4 @@ jobs: git diff exit 1 fi - echo "README.md is up-to-date" + echo "Docs are up-to-date." diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..81e7f5f7 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,40 @@ + +These docs are a work-in-progress. + +# Features + +These docs mirror the [official MCP spec](https://modelcontextprotocol.io/specification/2025-06-18). + +## Base Protocol + +1. [Lifecycle (Clients, Servers, and Sessions)](protocol.md#lifecycle). +1. [Transports](protocol.md#transports) + 1. [Stdio transport](protocol.md#stdio-transport) + 1. [Streamable transport](protocol.md#streamable-transport) + 1. [Custom transports](protocol.md#stateless-mode) +1. [Authorization](protocol.md#authorization) +1. [Security](protocol.md#security) +1. [Utilities](protocol.md#utilities) + 1. [Cancellation](utilities.md#cancellation) + 1. [Ping](utilities.md#ping) + 1. [Progress](utilities.md#progress) + +## Client Features + +1. [Roots](client.md#roots) +1. [Sampling](client.md#sampling) +1. [Elicitation](clients.md#elicitation) + +## Server Features + +1. [Prompts](server.md#prompts) +1. [Resources](server.md#resources) +1. [Tools](tools.md) +1. [Utilities](server.md#utilities) + 1. [Completion](server.md#completion) + 1. [Logging](server.md#logging) + 1. [Pagination](server.md#pagination) + +# TroubleShooting + +See [troubleshooting.md](troubleshooting.md) for a troubleshooting guide. diff --git a/docs/client.md b/docs/client.md new file mode 100644 index 00000000..13c12d57 --- /dev/null +++ b/docs/client.md @@ -0,0 +1,168 @@ + +# Support for MCP client features + +1. [Roots](#roots) +1. [Sampling](#sampling) +1. [Elicitation](#elicitation) + +## Roots + +MCP allows clients to specify a set of filesystem +["roots"](https://modelcontextprotocol.io/specification/2025-06-18/client/roots). +The SDK supports this as follows: + +**Client-side**: The SDK client always has the `roots.listChanged` capability. +To add roots to a client, use the +[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddRoots) +and +[`Client.RemoveRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.RemoveRoots) +methods. If any servers are already [connected](protocol.md#lifecycle) to the +client, a call to `AddRoot` or `RemoveRoots` will result in a +`notifications/roots/list_changed` notification to each connected server. + +**Server-side**: To query roots from the server, use the +[`ServerSession.ListRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.ListRoots) +method. To receive notifications about root changes, set +[`ServerOptions.RootsListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.RootsListChangedHandler). + +```go +func Example_roots() { + ctx := context.Background() + + // Create a client with a single root. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + c.AddRoots(&mcp.Root{URI: "file://a"}) + + // Now create a server with a handler to receive notifications about roots. + rootsChanged := make(chan struct{}) + handleRootsChanged := func(ctx context.Context, req *mcp.RootsListChangedRequest) { + rootList, err := req.Session.ListRoots(ctx, nil) + if err != nil { + log.Fatal(err) + } + var roots []string + for _, root := range rootList.Roots { + roots = append(roots, root.URI) + } + fmt.Println(roots) + close(rootsChanged) + } + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + RootsListChangedHandler: handleRootsChanged, + }) + + // Connect the server and client... + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + if _, err := c.Connect(ctx, t2, nil); err != nil { + log.Fatal(err) + } + + // ...and add a root. The server is notified about the change. + c.AddRoots(&mcp.Root{URI: "file://b"}) + <-rootsChanged + // Output: [file://a file://b] +} +``` + +## Sampling + +[Sampling](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling) +is a way for servers to leverage the client's AI capabilities. It is +implemented in the SDK as follows: + +**Client-side**: To add the `sampling` capability to a client, set +[`ClientOptions.CreateMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.CreateMessageHandler). +This function is invoked whenever the server requests sampling. + +**Server-side**: To use sampling from the server, call +[`ServerSession.CreateMessage`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.CreateMessage). + +```go +func Example_sampling() { + ctx := context.Background() + + // Create a client with a sampling handler. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + return &mcp.CreateMessageResult{ + Content: &mcp.TextContent{ + Text: "would have created a message", + }, + }, nil + }, + }) + + // Connect the server and client... + ct, st := mcp.NewInMemoryTransports() + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + session, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + + msg, err := session.CreateMessage(ctx, &mcp.CreateMessageParams{}) + if err != nil { + log.Fatal(err) + } + fmt.Println(msg.Content.(*mcp.TextContent).Text) + // Output: would have created a message +} +``` + +## Elicitation + +[Elicitation](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation) +allows servers to request user inputs. It is implemented in the SDK as follows: + +**Client-side**: To add the `elicitation` capability to a client, set +[`ClientOptions.ElicitationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ElicitationHandler). +The elicitation handler must return a result that matches the requested schema; +otherwise, elicitation returns an error. + +**Server-side**: To use elicitation from the server, call +[`ServerSession.Elicit`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Elicit). + +```go +func Example_elicitation() { + ctx := context.Background() + ct, st := mcp.NewInMemoryTransports() + + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + c := mcp.NewClient(testImpl, &mcp.ClientOptions{ + ElicitationHandler: func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + res, err := ss.Elicit(ctx, &mcp.ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + if err != nil { + log.Fatal(err) + } + fmt.Println(res.Content["test"]) + // Output: value +} +``` diff --git a/docs/protocol.md b/docs/protocol.md new file mode 100644 index 00000000..dbc4c1cb --- /dev/null +++ b/docs/protocol.md @@ -0,0 +1,368 @@ + +# Support for the MCP base protocol + +1. [Lifecycle](#lifecycle) +1. [Transports](#transports) + 1. [Stdio Transport](#stdio-transport) + 1. [Streamable Transport](#streamable-transport) + 1. [Custom transports](#custom-transports) + 1. [Concurrency](#concurrency) +1. [Authorization](#authorization) +1. [Security](#security) +1. [Utilities](#utilities) + 1. [Cancellation](#cancellation) + 1. [Ping](#ping) + 1. [Progress](#progress) + +## Lifecycle + +The SDK provides an API for defining both MCP clients and servers, and +connecting them over various transports. When a client and server are +connected, it creates a logical session, which follows the MCP spec's +[lifecycle](https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle). + +In this SDK, both a +[`Client`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client) +and +[`Server`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server) +can handle multiple peers. Every time a new peer is connected, it creates a new +session. + +- A `Client` is a logical MCP client, configured with various + [`ClientOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions). +- When a client is connected to a server using + [`Client.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.Connect), + it creates a + [`ClientSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession). + This session is initialized during the `Connect` method, and provides methods + to communicate with the server peer. +- A `Server` is a logical MCP server, configured with various + [`ServerOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions). +- When a server is connected to a client using + [`Server.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.Connect), + it creates a + [`ServerSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession). + This session is not initialized until the client sends the + `notifications/initialized` message. Use `ServerOptions.InitializedHandler` + to listen for this event, or just use the session through various feature + handlers (such as a `ToolHandler`). Requests to the server are rejected until + the client has initialized the session. + +Both `ClientSession` and `ServerSession` have a `Close` method to terminate the +session, and a `Wait` method to await session termination by the peer. Typically, +it is the client's responsibility to end the session. + +```go +func Example_lifecycle() { + ctx := context.Background() + + // Create a client and server. + // Wait for the client to initialize the session. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + InitializedHandler: func(context.Context, *mcp.InitializedRequest) { + fmt.Println("initialized!") + }, + }) + + // Connect the server and client using in-memory transports. + // + // Connect the server first so that it's ready to receive initialization + // messages from the client. + t1, t2 := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + clientSession, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + + // Now shut down the session by closing the client, and waiting for the + // server session to end. + if err := clientSession.Close(); err != nil { + log.Fatal(err) + } + if err := serverSession.Wait(); err != nil { + log.Fatal(err) + } + // Output: initialized! +} +``` + +## Transports + +A +[transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports) +can be used to send JSON-RPC messages from client to server, or vice-versa. + +In the SDK, this is achieved by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface, which creates a (logical) bidirectional stream of JSON-RPC messages. +Most transport implementations described below are specific to either the +client or server: a "client transport" is something that can be used to connect +a client to a server, and a "server transport" is something that can be used to +connect a server to a client. However, it's possible for a transport to be both +a client and server transport, such as the `InMemoryTransport` used in the +lifecycle example above. + +Transports should not be reused for multiple connections: if you need to create +multiple connections, use different transports. + +### Stdio Transport + +In the +[`stdio`](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio) +transport clients communicate with an MCP server running in a subprocess using +newline-delimited JSON over its stdin/stdout. + +**Client-side**: the client side of the `stdio` transport is implemented by +[`CommandTransport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CommandTransport), +which starts the a `exec.Cmd` as a subprocess and communicates over its +stdin/stdout. + +**Server-side**: the server side of the `stdio` transport is implemented by +`StdioTransport`, which connects over the current processes `os.Stdin` and +`os.Stdout`. + +### Streamable Transport + +The [streamable +transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http) +API is implemented across three types: + +- `StreamableHTTPHandler`: an`http.Handler` that serves streamable MCP + sessions. +- `StreamableServerTransport`: a `Transport` that implements the server side of + the streamable transport. +- `StreamableClientTransport`: a `Transport` that implements the client side of + the streamable transport. + +To create a streamable MCP server, you create a `StreamableHTTPHandler` and +pass it an `mcp.Server`: + +```go +func ExampleStreamableHTTPHandler() { + // Create a new streamable handler, using the same MCP server for every request. + // + // Here, we configure it to serves application/json responses rather than + // text/event-stream, just so the output below doesn't use random event ids. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + resp := mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + fmt.Println(resp) + // Output: + // {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.1.0"}}} +} +``` + +The `StreamableHTTPHandler` handles the HTTP requests and creates a new +`StreamableServerTransport` for each new session. The transport is then used to +communicate with the client. + +On the client side, you create a `StreamableClientTransport` and use it to +connect to the server: + +```go +transport := &mcp.StreamableClientTransport{ + Endpoint: "http://localhost:8080/mcp", +} +client, err := mcp.Connect(ctx, transport, &mcp.ClientOptions{...}) +``` + +The `StreamableClientTransport` handles the HTTP requests and communicates with +the server using the streamable transport protocol. + +#### Stateless Mode + + + +#### Sessionless mode + + + +### Custom transports + + + +### Concurrency + +In general, MCP offers no guarantees about concurrency semantics: if a client +or server sends a notification, the spec says nothing about when the peer +observes that notification relative to other request. However, the Go SDK +implements the following heuristics: + +- If a notifying method (such as `notifications/progress` or + `notifications/initialized`) returns, then it is guaranteed that the peer + observes that notification before other notifications or calls from the same + client goroutine. +- Calls (such as `tools/call`) are handled asynchronously with respect to + each other. + +See +[modelcontextprotocol/go-sdk#26](https://github.com/modelcontextprotocol/go-sdk/issues/26) +for more background. + +## Authorization + + + +## Security + + + +## Utilities + +### Cancellation + +Cancellation is implemented with context cancellation. Cancelling a context +used in a method on `ClientSession` or `ServerSession` will terminate the RPC +and send a "notifications/cancelled" message to the peer. + +When an RPC exits due to a cancellation error, there's a guarantee that the +cancellation notification has been sent, but there's no guarantee that the +server has observed it (see [concurrency](#concurrency)). + +```go +func Example_cancellation() { + // For this example, we're going to be collecting observations from the + // server and client. + var clientResult, serverResult string + var wg sync.WaitGroup + wg.Add(2) + + // Create a server with a single slow tool. + // When the client cancels its request, the server should observe + // cancellation. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + started := make(chan struct{}, 1) // signals that the server started handling the tool call + mcp.AddTool(server, &mcp.Tool{Name: "slow"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + started <- struct{}{} + defer wg.Done() + select { + case <-time.After(5 * time.Second): + serverResult = "tool done" + case <-ctx.Done(): + serverResult = "tool canceled" + } + return &mcp.CallToolResult{}, nil, nil + }) + + // Connect a client to the server. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // Make a tool call, asynchronously. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + _, err = session.CallTool(ctx, &mcp.CallToolParams{Name: "slow"}) + clientResult = fmt.Sprintf("%v", err) + }() + + // As soon as the server has started handling the call, cancel it from the + // client side. + <-started + cancel() + wg.Wait() + + fmt.Println(clientResult) + fmt.Println(serverResult) + // Output: + // context canceled + // tool canceled +} +``` + +### Ping + +[Ping](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/ping) +support is symmetrical for client and server. + +To initiate a ping, call +[`ClientSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Ping) +or +[`ServerSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Ping). + +To have the client or server session automatically ping its peer, and close the +session if the ping fails, set +[`ClientOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.KeepAlive) +or +[`ServerOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.KeepAlive). + +### Progress + +[Progress](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress) +reporting is possible by reading the progress token from request metadata and +calling either +[`ClientSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.NotifyProgress) +or +[`ServerSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.NotifyProgress). +To listen to progress notifications, set +[`ClientOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ProgressNotificationHandler) +or +[`ServerOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.ProgressNotificationHandler). + +Issue #460 discusses some potential ergonomic improvements to this API. + +```go +func Example_progress() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "makeProgress"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if token := req.Params.GetProgressToken(); token != nil { + for i := range 3 { + params := &mcp.ProgressNotificationParams{ + Message: "frobbing widgets", + ProgressToken: token, + Progress: float64(i), + Total: 2, + } + req.Session.NotifyProgress(ctx, params) // ignore error + } + } + return &mcp.CallToolResult{}, nil, nil + }) + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + ProgressNotificationHandler: func(_ context.Context, req *mcp.ProgressNotificationClientRequest) { + fmt.Printf("%s %.0f/%.0f\n", req.Params.Message, req.Params.Progress, req.Params.Total) + }, + }) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "makeProgress", + Meta: mcp.Meta{"progressToken": "abc123"}, + }); err != nil { + log.Fatal(err) + } + // Output: + // frobbing widgets 0/2 + // frobbing widgets 1/2 + // frobbing widgets 2/2 +} +``` diff --git a/docs/server.md b/docs/server.md new file mode 100644 index 00000000..5b17f2e5 --- /dev/null +++ b/docs/server.md @@ -0,0 +1,38 @@ + +# Support for MCP server features + +1. [Prompts](#prompts) +1. [Resources](#resources) +1. [Tools](#tools) +1. [Utilities](#utilities) + 1. [Completion](#completion) + 1. [Logging](#logging) + 1. [Pagination](#pagination) + +## Prompts + + + +## Resources + + + +## Tools + + + +## Utilities + + + +### Completion + + + +### Logging + + + +### Pagination + + diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 00000000..c0f021b6 --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,90 @@ + +# Troubleshooting + +The Model Context Protocol is a complicated spec that leaves some room for +interpretation. Client and server SDKs can behave differently, or can be more +or less strict about their inputs. And of course, bugs happen. + +When you encounter a problem using the Go SDK, these instructions can help +collect information that will be useful in debugging. Please try to provide +this information in a bug report, so that maintainers can more quickly +understand what's going wrong. + +And most of all, please do [file bugs](https://github.com/modelcontextprotocol/go-sdk/issues/new?template=bug_report.md). + +## Using the MCP inspector + +To debug an MCP server, you can use the [MCP +inspector](https://modelcontextprotocol.io/legacy/tools/inspector). This is +useful for testing your server and verifying that it works with the typescript +SDK, as well as inspecting MCP traffic. + +## Collecting MCP logs + +For [stdio](protocol.md#stdio-transport) transport connections, you can also +inspect MCP traffic using a `LoggingTransport`: + +```go +func ExampleLoggingTransport() { + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + var b bytes.Buffer + logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b} + if _, err := client.Connect(ctx, logTransport, nil); err != nil { + log.Fatal(err) + } + fmt.Println(b.String()) + // Output: + // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} + // read: {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.0.1"}}} + // write: {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} + +} +``` + +That example uses a `bytes.Buffer`, but you can also log to a file, or to +`os.Stderr`. + +## Inspecting HTTP traffic + +There are a couple different ways to investigate traffic to an HTTP transport +([streamable](protocol.md#streamable-transport) or legacy SSE). + +The first is to use an HTTP middleware: + +```go +func ExampleStreamableHTTPHandler_middleware() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, nil) + loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Example debugging; you could also capture the response. + body, err := io.ReadAll(req.Body) + if err != nil { + log.Fatal(err) + } + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewBuffer(body)) + fmt.Println(req.Method, string(body)) + handler.ServeHTTP(w, req) + }) + httpServer := httptest.NewServer(loggingHandler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + // Output: + // POST {"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}} +} +``` + +The second is to use a general purpose tool to inspect http traffic, such as +[wireshark](https://www.wireshark.org/) or +[tcpdump](https://linux.die.net/man/8/tcpdump). diff --git a/internal/docs/README.src.md b/internal/docs/README.src.md new file mode 100644 index 00000000..fb600df3 --- /dev/null +++ b/internal/docs/README.src.md @@ -0,0 +1,39 @@ +These docs are a work-in-progress. + +# Features + +These docs mirror the [official MCP spec](https://modelcontextprotocol.io/specification/2025-06-18). + +## Base Protocol + +1. [Lifecycle (Clients, Servers, and Sessions)](protocol.md#lifecycle). +1. [Transports](protocol.md#transports) + 1. [Stdio transport](protocol.md#stdio-transport) + 1. [Streamable transport](protocol.md#streamable-transport) + 1. [Custom transports](protocol.md#stateless-mode) +1. [Authorization](protocol.md#authorization) +1. [Security](protocol.md#security) +1. [Utilities](protocol.md#utilities) + 1. [Cancellation](utilities.md#cancellation) + 1. [Ping](utilities.md#ping) + 1. [Progress](utilities.md#progress) + +## Client Features + +1. [Roots](client.md#roots) +1. [Sampling](client.md#sampling) +1. [Elicitation](clients.md#elicitation) + +## Server Features + +1. [Prompts](server.md#prompts) +1. [Resources](server.md#resources) +1. [Tools](tools.md) +1. [Utilities](server.md#utilities) + 1. [Completion](server.md#completion) + 1. [Logging](server.md#logging) + 1. [Pagination](server.md#pagination) + +# TroubleShooting + +See [troubleshooting.md](troubleshooting.md) for a troubleshooting guide. diff --git a/internal/docs/client.src.md b/internal/docs/client.src.md new file mode 100644 index 00000000..f342719e --- /dev/null +++ b/internal/docs/client.src.md @@ -0,0 +1,55 @@ +# Support for MCP client features + +%toc + +## Roots + +MCP allows clients to specify a set of filesystem +["roots"](https://modelcontextprotocol.io/specification/2025-06-18/client/roots). +The SDK supports this as follows: + +**Client-side**: The SDK client always has the `roots.listChanged` capability. +To add roots to a client, use the +[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddRoots) +and +[`Client.RemoveRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.RemoveRoots) +methods. If any servers are already [connected](protocol.md#lifecycle) to the +client, a call to `AddRoot` or `RemoveRoots` will result in a +`notifications/roots/list_changed` notification to each connected server. + +**Server-side**: To query roots from the server, use the +[`ServerSession.ListRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.ListRoots) +method. To receive notifications about root changes, set +[`ServerOptions.RootsListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.RootsListChangedHandler). + +%include ../../mcp/client_example_test.go roots - + +## Sampling + +[Sampling](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling) +is a way for servers to leverage the client's AI capabilities. It is +implemented in the SDK as follows: + +**Client-side**: To add the `sampling` capability to a client, set +[`ClientOptions.CreateMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.CreateMessageHandler). +This function is invoked whenever the server requests sampling. + +**Server-side**: To use sampling from the server, call +[`ServerSession.CreateMessage`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.CreateMessage). + +%include ../../mcp/client_example_test.go sampling - + +## Elicitation + +[Elicitation](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation) +allows servers to request user inputs. It is implemented in the SDK as follows: + +**Client-side**: To add the `elicitation` capability to a client, set +[`ClientOptions.ElicitationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ElicitationHandler). +The elicitation handler must return a result that matches the requested schema; +otherwise, elicitation returns an error. + +**Server-side**: To use elicitation from the server, call +[`ServerSession.Elicit`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Elicit). + +%include ../../mcp/client_example_test.go elicitation - diff --git a/internal/docs/doc.go b/internal/docs/doc.go new file mode 100644 index 00000000..7b23ad63 --- /dev/null +++ b/internal/docs/doc.go @@ -0,0 +1,15 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:generate -command weave go run golang.org/x/example/internal/cmd/weave@latest +//go:generate weave -o ../../docs/README.md ./README.src.md +//go:generate weave -o ../../docs/protocol.md ./protocol.src.md +//go:generate weave -o ../../docs/client.md ./client.src.md +//go:generate weave -o ../../docs/server.md ./server.src.md +//go:generate weave -o ../../docs/troubleshooting.md ./troubleshooting.src.md + +// The doc package generates the documentation at /doc, via go:generate. +// +// Tests in this package are used for examples. +package docs diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md new file mode 100644 index 00000000..79b72418 --- /dev/null +++ b/internal/docs/protocol.src.md @@ -0,0 +1,198 @@ +# Support for the MCP base protocol + +%toc + +## Lifecycle + +The SDK provides an API for defining both MCP clients and servers, and +connecting them over various transports. When a client and server are +connected, it creates a logical session, which follows the MCP spec's +[lifecycle](https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle). + +In this SDK, both a +[`Client`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client) +and +[`Server`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server) +can handle multiple peers. Every time a new peer is connected, it creates a new +session. + +- A `Client` is a logical MCP client, configured with various + [`ClientOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions). +- When a client is connected to a server using + [`Client.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.Connect), + it creates a + [`ClientSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession). + This session is initialized during the `Connect` method, and provides methods + to communicate with the server peer. +- A `Server` is a logical MCP server, configured with various + [`ServerOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions). +- When a server is connected to a client using + [`Server.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.Connect), + it creates a + [`ServerSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession). + This session is not initialized until the client sends the + `notifications/initialized` message. Use `ServerOptions.InitializedHandler` + to listen for this event, or just use the session through various feature + handlers (such as a `ToolHandler`). Requests to the server are rejected until + the client has initialized the session. + +Both `ClientSession` and `ServerSession` have a `Close` method to terminate the +session, and a `Wait` method to await session termination by the peer. Typically, +it is the client's responsibility to end the session. + +%include ../../mcp/mcp_example_test.go lifecycle - + +## Transports + +A +[transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports) +can be used to send JSON-RPC messages from client to server, or vice-versa. + +In the SDK, this is achieved by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface, which creates a (logical) bidirectional stream of JSON-RPC messages. +Most transport implementations described below are specific to either the +client or server: a "client transport" is something that can be used to connect +a client to a server, and a "server transport" is something that can be used to +connect a server to a client. However, it's possible for a transport to be both +a client and server transport, such as the `InMemoryTransport` used in the +lifecycle example above. + +Transports should not be reused for multiple connections: if you need to create +multiple connections, use different transports. + +### Stdio Transport + +In the +[`stdio`](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio) +transport clients communicate with an MCP server running in a subprocess using +newline-delimited JSON over its stdin/stdout. + +**Client-side**: the client side of the `stdio` transport is implemented by +[`CommandTransport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CommandTransport), +which starts the a `exec.Cmd` as a subprocess and communicates over its +stdin/stdout. + +**Server-side**: the server side of the `stdio` transport is implemented by +`StdioTransport`, which connects over the current processes `os.Stdin` and +`os.Stdout`. + +### Streamable Transport + +The [streamable +transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http) +API is implemented across three types: + +- `StreamableHTTPHandler`: an`http.Handler` that serves streamable MCP + sessions. +- `StreamableServerTransport`: a `Transport` that implements the server side of + the streamable transport. +- `StreamableClientTransport`: a `Transport` that implements the client side of + the streamable transport. + +To create a streamable MCP server, you create a `StreamableHTTPHandler` and +pass it an `mcp.Server`: + +%include ../../mcp/streamable_example_test.go streamablehandler - + +The `StreamableHTTPHandler` handles the HTTP requests and creates a new +`StreamableServerTransport` for each new session. The transport is then used to +communicate with the client. + +On the client side, you create a `StreamableClientTransport` and use it to +connect to the server: + +```go +transport := &mcp.StreamableClientTransport{ + Endpoint: "http://localhost:8080/mcp", +} +client, err := mcp.Connect(ctx, transport, &mcp.ClientOptions{...}) +``` + +The `StreamableClientTransport` handles the HTTP requests and communicates with +the server using the streamable transport protocol. + +#### Stateless Mode + + + +#### Sessionless mode + + + +### Custom transports + + + +### Concurrency + +In general, MCP offers no guarantees about concurrency semantics: if a client +or server sends a notification, the spec says nothing about when the peer +observes that notification relative to other request. However, the Go SDK +implements the following heuristics: + +- If a notifying method (such as `notifications/progress` or + `notifications/initialized`) returns, then it is guaranteed that the peer + observes that notification before other notifications or calls from the same + client goroutine. +- Calls (such as `tools/call`) are handled asynchronously with respect to + each other. + +See +[modelcontextprotocol/go-sdk#26](https://github.com/modelcontextprotocol/go-sdk/issues/26) +for more background. + +## Authorization + + + +## Security + + + +## Utilities + +### Cancellation + +Cancellation is implemented with context cancellation. Cancelling a context +used in a method on `ClientSession` or `ServerSession` will terminate the RPC +and send a "notifications/cancelled" message to the peer. + +When an RPC exits due to a cancellation error, there's a guarantee that the +cancellation notification has been sent, but there's no guarantee that the +server has observed it (see [concurrency](#concurrency)). + +%include ../../mcp/mcp_example_test.go cancellation - + +### Ping + +[Ping](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/ping) +support is symmetrical for client and server. + +To initiate a ping, call +[`ClientSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Ping) +or +[`ServerSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Ping). + +To have the client or server session automatically ping its peer, and close the +session if the ping fails, set +[`ClientOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.KeepAlive) +or +[`ServerOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.KeepAlive). + +### Progress + +[Progress](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress) +reporting is possible by reading the progress token from request metadata and +calling either +[`ClientSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.NotifyProgress) +or +[`ServerSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.NotifyProgress). +To listen to progress notifications, set +[`ClientOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ProgressNotificationHandler) +or +[`ServerOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.ProgressNotificationHandler). + +Issue #460 discusses some potential ergonomic improvements to this API. + +%include ../../mcp/mcp_example_test.go progress - diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md new file mode 100644 index 00000000..a131bcd3 --- /dev/null +++ b/internal/docs/server.src.md @@ -0,0 +1,31 @@ +# Support for MCP server features + +%toc + +## Prompts + + + +## Resources + + + +## Tools + + + +## Utilities + + + +### Completion + + + +### Logging + + + +### Pagination + + diff --git a/internal/docs/troubleshooting.src.md b/internal/docs/troubleshooting.src.md new file mode 100644 index 00000000..83342032 --- /dev/null +++ b/internal/docs/troubleshooting.src.md @@ -0,0 +1,42 @@ +# Troubleshooting + +The Model Context Protocol is a complicated spec that leaves some room for +interpretation. Client and server SDKs can behave differently, or can be more +or less strict about their inputs. And of course, bugs happen. + +When you encounter a problem using the Go SDK, these instructions can help +collect information that will be useful in debugging. Please try to provide +this information in a bug report, so that maintainers can more quickly +understand what's going wrong. + +And most of all, please do [file bugs](https://github.com/modelcontextprotocol/go-sdk/issues/new?template=bug_report.md). + +## Using the MCP inspector + +To debug an MCP server, you can use the [MCP +inspector](https://modelcontextprotocol.io/legacy/tools/inspector). This is +useful for testing your server and verifying that it works with the typescript +SDK, as well as inspecting MCP traffic. + +## Collecting MCP logs + +For [stdio](protocol.md#stdio-transport) transport connections, you can also +inspect MCP traffic using a `LoggingTransport`: + +%include ../../mcp/transport_example_test.go loggingtransport - + +That example uses a `bytes.Buffer`, but you can also log to a file, or to +`os.Stderr`. + +## Inspecting HTTP traffic + +There are a couple different ways to investigate traffic to an HTTP transport +([streamable](protocol.md#streamable-transport) or legacy SSE). + +The first is to use an HTTP middleware: + +%include ../../mcp/streamable_example_test.go httpmiddleware - + +The second is to use a general purpose tool to inspect http traffic, such as +[wireshark](https://www.wireshark.org/) or +[tcpdump](https://linux.die.net/man/8/tcpdump). diff --git a/internal/readme/client/client.go b/internal/readme/client/client.go index e2794f8b..9f357964 100644 --- a/internal/readme/client/client.go +++ b/internal/readme/client/client.go @@ -44,4 +44,4 @@ func main() { } } -//!- +// !- diff --git a/mcp/client.go b/mcp/client.go index 1ed3b048..822566de 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -55,11 +55,15 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { // ClientOptions configures the behavior of the client. type ClientOptions struct { - // Handler for sampling. - // Called when a server calls CreateMessage. + // CreateMessageHandler handles incoming requests for sampling/createMessage. + // + // Setting CreateMessageHandler to a non-nil value causes the client to + // advertise the sampling capability. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) - // Handler for elicitation. - // Called when a server requests user input via Elicit. + // ElicitationHandler handles incoming requests for elicitation/create. + // + // Setting ElicitationHandler to a non-nil value causes the client to + // advertise the elicitation capability. ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) // Handlers for notifications from the server. ToolListChangedHandler func(context.Context, *ToolListChangedRequest) @@ -123,7 +127,7 @@ func (c *Client) capabilities() *ClientCapabilities { } // Connect begins an MCP session by connecting to a server over the given -// transport, and initializing the session. +// transport. The resulting session is initialized, and ready to use. // // Typically, it is the responsibility of the client to close the connection // when it is no longer needed. However, if the connection is closed by the @@ -302,6 +306,9 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, // Validate elicitation result content against requested schema if req.Params.RequestedSchema != nil && res.Content != nil { + // TODO: is this the correct behavior if validation fails? + // It isn't the *server's* params that are invalid, so why would we return + // this code to the server? resolved, err := req.Params.RequestedSchema.Resolve(nil) if err != nil { return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) diff --git a/mcp/client_example_test.go b/mcp/client_example_test.go new file mode 100644 index 00000000..3c3c3837 --- /dev/null +++ b/mcp/client_example_test.go @@ -0,0 +1,136 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+roots + +func Example_roots() { + ctx := context.Background() + + // Create a client with a single root. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + c.AddRoots(&mcp.Root{URI: "file://a"}) + + // Now create a server with a handler to receive notifications about roots. + rootsChanged := make(chan struct{}) + handleRootsChanged := func(ctx context.Context, req *mcp.RootsListChangedRequest) { + rootList, err := req.Session.ListRoots(ctx, nil) + if err != nil { + log.Fatal(err) + } + var roots []string + for _, root := range rootList.Roots { + roots = append(roots, root.URI) + } + fmt.Println(roots) + close(rootsChanged) + } + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + RootsListChangedHandler: handleRootsChanged, + }) + + // Connect the server and client... + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + if _, err := c.Connect(ctx, t2, nil); err != nil { + log.Fatal(err) + } + + // ...and add a root. The server is notified about the change. + c.AddRoots(&mcp.Root{URI: "file://b"}) + <-rootsChanged + // Output: [file://a file://b] +} + +// !-roots + +// !+sampling + +func Example_sampling() { + ctx := context.Background() + + // Create a client with a sampling handler. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + return &mcp.CreateMessageResult{ + Content: &mcp.TextContent{ + Text: "would have created a message", + }, + }, nil + }, + }) + + // Connect the server and client... + ct, st := mcp.NewInMemoryTransports() + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + session, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + + msg, err := session.CreateMessage(ctx, &mcp.CreateMessageParams{}) + if err != nil { + log.Fatal(err) + } + fmt.Println(msg.Content.(*mcp.TextContent).Text) + // Output: would have created a message +} + +// !-sampling + +// !+elicitation + +func Example_elicitation() { + ctx := context.Background() + ct, st := mcp.NewInMemoryTransports() + + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + c := mcp.NewClient(testImpl, &mcp.ClientOptions{ + ElicitationHandler: func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + res, err := ss.Elicit(ctx, &mcp.ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + if err != nil { + log.Fatal(err) + } + fmt.Println(res.Content["test"]) + // Output: value +} + +// !-elicitation diff --git a/mcp/mcp_example_test.go b/mcp/mcp_example_test.go new file mode 100644 index 00000000..25f39fb8 --- /dev/null +++ b/mcp/mcp_example_test.go @@ -0,0 +1,166 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+lifecycle + +func Example_lifecycle() { + ctx := context.Background() + + // Create a client and server. + // Wait for the client to initialize the session. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + InitializedHandler: func(context.Context, *mcp.InitializedRequest) { + fmt.Println("initialized!") + }, + }) + + // Connect the server and client using in-memory transports. + // + // Connect the server first so that it's ready to receive initialization + // messages from the client. + t1, t2 := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + clientSession, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + + // Now shut down the session by closing the client, and waiting for the + // server session to end. + if err := clientSession.Close(); err != nil { + log.Fatal(err) + } + if err := serverSession.Wait(); err != nil { + log.Fatal(err) + } + // Output: initialized! +} + +// !-lifecycle + +// !+progress + +func Example_progress() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "makeProgress"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if token := req.Params.GetProgressToken(); token != nil { + for i := range 3 { + params := &mcp.ProgressNotificationParams{ + Message: "frobbing widgets", + ProgressToken: token, + Progress: float64(i), + Total: 2, + } + req.Session.NotifyProgress(ctx, params) // ignore error + } + } + return &mcp.CallToolResult{}, nil, nil + }) + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + ProgressNotificationHandler: func(_ context.Context, req *mcp.ProgressNotificationClientRequest) { + fmt.Printf("%s %.0f/%.0f\n", req.Params.Message, req.Params.Progress, req.Params.Total) + }, + }) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "makeProgress", + Meta: mcp.Meta{"progressToken": "abc123"}, + }); err != nil { + log.Fatal(err) + } + // Output: + // frobbing widgets 0/2 + // frobbing widgets 1/2 + // frobbing widgets 2/2 +} + +// !-progress + +// !+cancellation + +func Example_cancellation() { + // For this example, we're going to be collecting observations from the + // server and client. + var clientResult, serverResult string + var wg sync.WaitGroup + wg.Add(2) + + // Create a server with a single slow tool. + // When the client cancels its request, the server should observe + // cancellation. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + started := make(chan struct{}, 1) // signals that the server started handling the tool call + mcp.AddTool(server, &mcp.Tool{Name: "slow"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + started <- struct{}{} + defer wg.Done() + select { + case <-time.After(5 * time.Second): + serverResult = "tool done" + case <-ctx.Done(): + serverResult = "tool canceled" + } + return &mcp.CallToolResult{}, nil, nil + }) + + // Connect a client to the server. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // Make a tool call, asynchronously. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + _, err = session.CallTool(ctx, &mcp.CallToolParams{Name: "slow"}) + clientResult = fmt.Sprintf("%v", err) + }() + + // As soon as the server has started handling the call, cancel it from the + // client side. + <-started + cancel() + wg.Wait() + + fmt.Println(clientResult) + fmt.Println(serverResult) + // Output: + // context canceled + // tool canceled +} + +// !-cancellation diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 6191954c..dd542d3d 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -705,7 +705,7 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + slowTool := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -716,7 +716,7 @@ func TestCancellation(t *testing.T) { return nil, nil, nil } cs, _ := basicConnection(t, func(s *Server) { - AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowRequest) + AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowTool) }) defer cs.Close() @@ -1109,7 +1109,7 @@ func TestElicitationSchemaValidation(t *testing.T) { "low", }, Extra: map[string]any{ - "enumNames": []interface{}{"High Priority", "Medium Priority", "Low Priority"}, + "enumNames": []any{"High Priority", "Medium Priority", "Low Priority"}, }, }, }, @@ -1270,7 +1270,7 @@ func TestElicitationSchemaValidation(t *testing.T) { "low", }, Extra: map[string]any{ - "enumNames": []interface{}{"High Priority", "Medium Priority"}, // Only 2 names for 3 values + "enumNames": []any{"High Priority", "Medium Priority"}, // Only 2 names for 3 values }, }, }, diff --git a/mcp/root.go b/mcp/root.go deleted file mode 100644 index b56ad991..00000000 --- a/mcp/root.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package mcp diff --git a/mcp/streamable.go b/mcp/streamable.go index bfaccae4..8ac6f59a 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -70,9 +70,12 @@ type StreamableHTTPOptions struct { // documentation for [StreamableServerTransport]. Stateless bool - // TODO: support session retention (?) + // TODO(#148): support session retention (?) - // JSONResponse is forwarded to StreamableServerTransport.jsonResponse. + // JSONResponse causes streamable responses to return application/json rather + // than text/event-stream ([§2.1.5] of the spec). + // + // [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server JSONResponse bool } @@ -181,7 +184,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - // Section 2.7 of the spec (2025-06-18) states: + // [§2.7] of the spec (2025-06-18) states: // // "If using HTTP, the client MUST include the MCP-Protocol-Version: // HTTP header on all subsequent requests to the MCP @@ -209,6 +212,8 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // assume 2025-03-26 if the client doesn't say anything). // // This logic matches the typescript SDK. + // + // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header protocolVersion := req.Header.Get(protocolVersionHeader) if protocolVersion == "" { protocolVersion = protocolVersion20250326 @@ -370,6 +375,9 @@ type StreamableServerTransport struct { // request contain only a single message. In this case, notifications or // requests made within the context of a server request will be sent to the // hanging GET request, if any. + // + // TODO(rfindley): jsonResponse should be exported, since + // StreamableHTTPOptions.JSONResponse is exported. jsonResponse bool // connection is non-nil if and only if the transport has been connected. @@ -1188,7 +1196,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %v", requestSummary, err) } - // Section 2.5.3: "The server MAY terminate the session at any time, after + // §2.5.3: "The server MAY terminate the session at any time, after // which it MUST respond to requests containing that session ID with HTTP // 404 Not Found." if resp.StatusCode == http.StatusNotFound { diff --git a/mcp/streamable_example_test.go b/mcp/streamable_example_test.go new file mode 100644 index 00000000..430f2745 --- /dev/null +++ b/mcp/streamable_example_test.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+streamablehandler + +func ExampleStreamableHTTPHandler() { + // Create a new streamable handler, using the same MCP server for every request. + // + // Here, we configure it to serves application/json responses rather than + // text/event-stream, just so the output below doesn't use random event ids. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + resp := mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + fmt.Println(resp) + // Output: + // {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.1.0"}}} +} + +// !-streamablehandler + +// !+httpmiddleware + +func ExampleStreamableHTTPHandler_middleware() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, nil) + loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Example debugging; you could also capture the response. + body, err := io.ReadAll(req.Body) + if err != nil { + log.Fatal(err) + } + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewBuffer(body)) + fmt.Println(req.Method, string(body)) + handler.ServeHTTP(w, req) + }) + httpServer := httptest.NewServer(loggingHandler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + // Output: + // POST {"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}} +} + +// !-httpmiddleware + +func mustPostMessage(msg, url string) string { + req := orFatal(http.NewRequest("POST", url, strings.NewReader(msg))) + req.Header["Content-Type"] = []string{"application/json"} + req.Header["Accept"] = []string{"application/json", "text/event-stream"} + resp := orFatal(http.DefaultClient.Do(req)) + defer resp.Body.Close() + body := orFatal(io.ReadAll(resp.Body)) + return string(body) +} + +func orFatal[T any](t T, err error) T { + if err != nil { + log.Fatal(err) + } + return t +} diff --git a/mcp/transport.go b/mcp/transport.go index 024863de..d2109e7d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -212,12 +212,14 @@ func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { if err != nil { return nil, err } - return &loggingConn{delegate, t.Writer}, nil + return &loggingConn{delegate: delegate, w: t.Writer}, nil } type loggingConn struct { delegate Connection - w io.Writer + + mu sync.Mutex + w io.Writer } func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } @@ -225,15 +227,21 @@ func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } // Read is a stream middleware that logs incoming messages. func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { msg, err := s.delegate.Read(ctx) + if err != nil { + s.mu.Lock() fmt.Fprintf(s.w, "read error: %v", err) + s.mu.Unlock() } else { data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() if err != nil { fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } fmt.Fprintf(s.w, "read: %s\n", string(data)) + s.mu.Unlock() } + return msg, err } @@ -241,13 +249,17 @@ func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { err := s.delegate.Write(ctx, msg) if err != nil { + s.mu.Lock() fmt.Fprintf(s.w, "write error: %v", err) + s.mu.Unlock() } else { data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() if err != nil { fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } fmt.Fprintf(s.w, "write: %s\n", string(data)) + s.mu.Unlock() } return err } diff --git a/mcp/transport_example_test.go b/mcp/transport_example_test.go new file mode 100644 index 00000000..dcf1a8ba --- /dev/null +++ b/mcp/transport_example_test.go @@ -0,0 +1,40 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "bytes" + "context" + "fmt" + "log" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+loggingtransport + +func ExampleLoggingTransport() { + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + var b bytes.Buffer + logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b} + if _, err := client.Connect(ctx, logTransport, nil); err != nil { + log.Fatal(err) + } + fmt.Println(b.String()) + // Output: + // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} + // read: {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.0.1"}}} + // write: {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} + +} + +// !-loggingtransport From 845c29f6b7e3d71537c143b542c5d8c7badd3837 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 16 Sep 2025 18:56:28 +0000 Subject: [PATCH 197/221] mcp: add a benchmark for the MemoryEventStore; remove validation Add a new benchmark measuring the append and purge performance of the MemoryEventStore. This benchmark revealed that the store is orders of magnitude slower than it should be due to conservative validation (hugely so: 300KB/s vs 568MB/s). Turn off this validation by default. For #190 --- mcp/event.go | 4 ++-- mcp/event_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 0dd8734b..d309c4e0 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -23,8 +23,8 @@ import ( ) // If true, MemoryEventStore will do frequent validation to check invariants, slowing it down. -// Remove when we're confident in the code. -const validateMemoryEventStore = true +// Enable for debugging. +const validateMemoryEventStore = false // An Event is a server-sent event. // See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. diff --git a/mcp/event_test.go b/mcp/event_test.go index 601e8300..20808c73 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -10,6 +10,7 @@ import ( "slices" "strings" "testing" + "time" ) func TestScanEvents(t *testing.T) { @@ -252,3 +253,46 @@ func TestMemoryEventStoreAfter(t *testing.T) { }) } } + +func BenchmarkMemoryEventStore(b *testing.B) { + // Benchmark with various settings for event store size, number of session, + // and payload size. + // + // Assume a small number of streams per session, which is probably realistic. + tests := []struct { + name string + limit int + sessions int + datasize int + }{ + {"1KB", 1024, 1, 16}, + {"1MB", 1024 * 1024, 10, 16}, + {"10MB", 10 * 1024 * 1024, 100, 16}, + {"10MB_big", 10 * 1024 * 1024, 1000, 128}, + } + + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + store := NewMemoryEventStore(nil) + store.SetMaxBytes(test.limit) + ctx := context.Background() + sessionIDs := make([]string, test.sessions) + streamIDs := make([][3]StreamID, test.sessions) + for i := range sessionIDs { + sessionIDs[i] = fmt.Sprint(i) + for j := range 3 { + streamIDs[i][j] = StreamID(randText()) + } + } + payload := make([]byte, test.datasize) + start := time.Now() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sessionID := sessionIDs[i%len(sessionIDs)] + streamID := streamIDs[i%len(sessionIDs)][i%3] + store.Append(ctx, sessionID, streamID, payload) + } + b.ReportMetric(float64(test.datasize)*float64(b.N)/time.Since(start).Seconds(), "bytes/s") + }) + } +} From b3fb83f272a8fba10dd64568e53cb1f201f653ca Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 16 Sep 2025 21:12:14 -0400 Subject: [PATCH 198/221] internal/testing: fake auth server (#296) This is a fake OAuth authentication server, for use in testing. --- go.mod | 1 + go.sum | 2 + internal/testing/fake_auth_server.go | 151 +++++++++++++++++++++++++++ 3 files changed, 154 insertions(+) create mode 100644 internal/testing/fake_auth_server.go diff --git a/go.mod b/go.mod index 56910893..06896252 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/modelcontextprotocol/go-sdk go 1.23.0 require ( + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/go-cmp v0.7.0 github.com/google/jsonschema-go v0.2.3 github.com/yosida95/uritemplate/v3 v3.0.2 diff --git a/go.sum b/go.sum index d30dd343..7ccdc6a8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo= diff --git a/internal/testing/fake_auth_server.go b/internal/testing/fake_auth_server.go new file mode 100644 index 00000000..225f649c --- /dev/null +++ b/internal/testing/fake_auth_server.go @@ -0,0 +1,151 @@ +package testing + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + authServerPort = ":8080" + issuer = "http://localhost" + authServerPort + tokenExpiry = time.Hour +) + +var jwtSigningKey = []byte("fake-secret-key") + +type authCodeInfo struct { + codeChallenge string + redirectURI string +} + +// // FakeAuthServer is a fake OAuth2 authorization server. +// type FakeAuthServer struct { +// server *http.Server +// authCodes map[string]authCodeInfo +// } + +type state struct { + authCodes map[string]authCodeInfo +} + +// NewFakeAuthMux constructs a ServeMux that implements an OAuth 2.1 authentication +// server. It should be used with [httptest.NewTLSServer]. +func NewFakeAuthMux() *http.ServeMux { + s := &state{authCodes: make(map[string]authCodeInfo)} + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/oauth-authorization-server", s.handleMetadata) + mux.HandleFunc("/authorize", s.handleAuthorize) + mux.HandleFunc("/token", s.handleToken) + return mux +} + +func (s *state) handleMetadata(w http.ResponseWriter, r *http.Request) { + issuer := "https://localhost:" + r.URL.Port() + metadata := map[string]any{ + "issuer": issuer, + "authorization_endpoint": issuer + "/authorize", + "token_endpoint": issuer + "/token", + "jwks_uri": issuer + "/.well-known/jwks.json", + "scopes_supported": []string{"openid", "profile", "email"}, + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "code_challenge_methods_supported": []string{"S256"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(metadata) +} + +func (s *state) handleAuthorize(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + responseType := query.Get("response_type") + redirectURI := query.Get("redirect_uri") + codeChallenge := query.Get("code_challenge") + codeChallengeMethod := query.Get("code_challenge_method") + + if responseType != "code" { + http.Error(w, "unsupported_response_type", http.StatusBadRequest) + return + } + if redirectURI == "" { + http.Error(w, "invalid_request (no redirect_uri)", http.StatusBadRequest) + return + } + if codeChallenge == "" || codeChallengeMethod != "S256" { + http.Error(w, "invalid_request (code challenge is not S256)", http.StatusBadRequest) + return + } + if query.Get("client_id") == "" { + http.Error(w, "invalid_request (missing client_id)", http.StatusBadRequest) + return + } + + authCode := "fake-auth-code-" + fmt.Sprintf("%d", time.Now().UnixNano()) + s.authCodes[authCode] = authCodeInfo{ + codeChallenge: codeChallenge, + redirectURI: redirectURI, + } + + redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, authCode, query.Get("state")) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +func (s *state) handleToken(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + grantType := r.Form.Get("grant_type") + code := r.Form.Get("code") + codeVerifier := r.Form.Get("code_verifier") + // Ignore redirect_uri; it is not required in 2.1. + // https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-13.html#redirect-uri-in-token-request + + if grantType != "authorization_code" { + http.Error(w, "unsupported_grant_type", http.StatusBadRequest) + return + } + + authCodeInfo, ok := s.authCodes[code] + if !ok { + http.Error(w, "invalid_grant", http.StatusBadRequest) + return + } + delete(s.authCodes, code) + + // PKCE verification + hasher := sha256.New() + hasher.Write([]byte(codeVerifier)) + calculatedChallenge := base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) + if calculatedChallenge != authCodeInfo.codeChallenge { + http.Error(w, "invalid_grant", http.StatusBadRequest) + return + } + + // Issue JWT + now := time.Now() + claims := jwt.MapClaims{ + "iss": issuer, + "sub": "fake-user-id", + "aud": "fake-client-id", + "exp": now.Add(tokenExpiry).Unix(), + "iat": now.Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + accessToken, err := token.SignedString(jwtSigningKey) + if err != nil { + http.Error(w, "server_error", http.StatusInternalServerError) + return + } + + tokenResponse := map[string]any{ + "access_token": accessToken, + "token_type": "Bearer", + "expires_in": int(tokenExpiry.Seconds()), + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse) +} From bd6ffe3b4c17eb027a8fdf6badabcabe03b0a8c5 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Tue, 16 Sep 2025 17:27:58 +0000 Subject: [PATCH 199/221] mcp: add an example customizing a nested type schema Add an example that demonstrates how to customize a nested type schema, which is a recurring problem for our users (see #467). Fixes #467 Updates #368 --- mcp/tool_example_test.go | 93 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 mcp/tool_example_test.go diff --git a/mcp/tool_example_test.go b/mcp/tool_example_test.go new file mode 100644 index 00000000..29ccd9c2 --- /dev/null +++ b/mcp/tool_example_test.go @@ -0,0 +1,93 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func ExampleAddTool_customTypeSchema() { + // Sometimes when you want to customize the input or output schema for a + // tool, you need to customize the schema of a single helper type that's used + // in several places. + // + // For example, suppose you had a type that marshals/unmarshals like a + // time.Time, and that type was used multiple times in your tool input. + type MyDate struct { + time.Time + } + type Input struct { + Query string `json:"query,omitempty"` + Start MyDate `json:"start,omitempty"` + End MyDate `json:"end,omitempty"` + } + + // In this case, you can use jsonschema.For along with jsonschema.ForOptions + // to customize the schema inference for your custom type. + inputSchema, err := jsonschema.For[Input](&jsonschema.ForOptions{ + TypeSchemas: map[any]*jsonschema.Schema{ + MyDate{}: {Type: "string"}, + }, + }) + if err != nil { + log.Fatal(err) + } + + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + toolHandler := func(context.Context, *mcp.CallToolRequest, Input) (*mcp.CallToolResult, any, error) { + panic("not implemented") + } + mcp.AddTool(server, &mcp.Tool{Name: "my_tool", InputSchema: inputSchema}, toolHandler) + + ctx := context.Background() + session, err := connect(ctx, server) // create an in-memory connection + if err != nil { + log.Fatal(err) + } + defer session.Close() + + for t, err := range session.Tools(ctx, nil) { + if err != nil { + log.Fatal(err) + } + schemaJSON, err := json.MarshalIndent(t.InputSchema, "", "\t") + if err != nil { + log.Fatal(err) + } + fmt.Println(t.Name, string(schemaJSON)) + } + // Output: + // my_tool { + // "type": "object", + // "properties": { + // "end": { + // "type": "string" + // }, + // "query": { + // "type": "string" + // }, + // "start": { + // "type": "string" + // } + // }, + // "additionalProperties": false + // } +} + +func connect(ctx context.Context, server *mcp.Server) (*mcp.ClientSession, error) { + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + return nil, err + } + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + return client.Connect(ctx, t2, nil) +} From 22f86c4dfdf440e9980a18d9d3a7a89618f6be3f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 16 Sep 2025 21:43:18 -0400 Subject: [PATCH 200/221] internal/testing: add copyright header (#477) --- examples/server/auth-middleware/go.mod | 2 +- examples/server/auth-middleware/go.sum | 4 ++-- examples/server/rate-limiting/go.mod | 2 +- examples/server/rate-limiting/go.sum | 4 ++-- internal/testing/fake_auth_server.go | 10 ++++------ 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod index f1b65d2b..f1ca77fa 100644 --- a/examples/server/auth-middleware/go.mod +++ b/examples/server/auth-middleware/go.mod @@ -8,7 +8,7 @@ require ( ) require ( - github.com/google/jsonschema-go v0.2.1 // indirect + github.com/google/jsonschema-go v0.2.3 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect ) diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum index ada94c0c..6a392638 100644 --- a/examples/server/auth-middleware/go.sum +++ b/examples/server/auth-middleware/go.sum @@ -2,8 +2,8 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.1 h1:Z3iINWAUmvS4+m9cMP5lWbn6WlX8Hy4rpUS4pULVliQ= -github.com/google/jsonschema-go v0.2.1/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod index 8dd2f808..91d3d269 100644 --- a/examples/server/rate-limiting/go.mod +++ b/examples/server/rate-limiting/go.mod @@ -8,7 +8,7 @@ require ( ) require ( - github.com/google/jsonschema-go v0.2.1 // indirect + github.com/google/jsonschema-go v0.2.3 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect ) diff --git a/examples/server/rate-limiting/go.sum b/examples/server/rate-limiting/go.sum index 9e6a390e..da49fd16 100644 --- a/examples/server/rate-limiting/go.sum +++ b/examples/server/rate-limiting/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.1 h1:Z3iINWAUmvS4+m9cMP5lWbn6WlX8Hy4rpUS4pULVliQ= -github.com/google/jsonschema-go v0.2.1/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= diff --git a/internal/testing/fake_auth_server.go b/internal/testing/fake_auth_server.go index 225f649c..79fafe4e 100644 --- a/internal/testing/fake_auth_server.go +++ b/internal/testing/fake_auth_server.go @@ -1,3 +1,7 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + package testing import ( @@ -24,12 +28,6 @@ type authCodeInfo struct { redirectURI string } -// // FakeAuthServer is a fake OAuth2 authorization server. -// type FakeAuthServer struct { -// server *http.Server -// authCodes map[string]authCodeInfo -// } - type state struct { authCodes map[string]authCodeInfo } From a3e753b8d8b5cb936c19f6110632b0c08377a8a4 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 17 Sep 2025 21:05:52 +0000 Subject: [PATCH 201/221] mcp: correct the JSON used for unstructured content Recently, I refactored to fix validation and default application for structured output. However, I used the wrong byte slice for the unstructured output. Fix it to use the same bytes as structured output, with a test. Fixes #475 --- mcp/server.go | 2 +- mcp/server_test.go | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 69808ac7..fd38a89a 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -292,7 +292,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. if res.Content == nil { res.Content = []Content{&TextContent{ - Text: string(outbytes), + Text: string(outJSON), }} } } diff --git a/mcp/server_test.go b/mcp/server_test.go index 4456495f..249ef90b 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -512,8 +512,7 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out Arguments: json.RawMessage(in), }, } - _, err = goth(context.Background(), ctr) - + result, err := goth(context.Background(), ctr) if wantErrContaining != "" { if err == nil { t.Errorf("got nil error, want error containing %q", wantErrContaining) @@ -525,8 +524,18 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out } else if err != nil { t.Errorf("got error %v, want no error", err) } + + if gott.OutputSchema != nil && err == nil && !result.IsError { + // Check that structured content matches exactly. + unstructured := result.Content[0].(*TextContent).Text + structured := string(result.StructuredContent.(json.RawMessage)) + if diff := cmp.Diff(unstructured, structured); diff != "" { + t.Errorf("Unstructured content does not match structured content exactly (-unstructured +structured):\n%s", diff) + } + } } +// TODO: move this to tool_test.go func TestToolForSchemas(t *testing.T) { // Validate that toolForErr handles schemas properly. type in struct { @@ -558,4 +567,24 @@ func TestToolForSchemas(t *testing.T) { // Tool sets output schema: that is what's used, and validation happens. testToolForSchema[in, any](t, &Tool{OutputSchema: outSchema2}, `{"p":3}`, out{true}, inSchema, outSchema2, `want "integer"`) + + // Check a slightly more complicated case. + type weatherOutput struct { + Summary string + AsOf time.Time + Source string + } + testToolForSchema[any](t, &Tool{}, `{}`, weatherOutput{}, + &schema{Type: "object"}, + &schema{ + Type: "object", + Required: []string{"Summary", "AsOf", "Source"}, + AdditionalProperties: falseSchema, + Properties: map[string]*schema{ + "Summary": {Type: "string"}, + "AsOf": {Type: "string"}, + "Source": {Type: "string"}, + }, + }, + "") } From b615fa49b1a7b59be11ff7fa5d2bca5cf93ebcda Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Wed, 17 Sep 2025 17:49:31 -0400 Subject: [PATCH 202/221] mcp: add a streamable serving benchmark Add a benchmark using the streamable HTTP handler, and a more realistic input/output signature. Also expose this handler as an example (both a go example and a server in the examples/ directory). --- mcp/event_test.go | 2 +- mcp/streamable_bench_test.go | 66 +++++++++++++++++++ mcp/tool_example_test.go | 120 ++++++++++++++++++++++++++++++++++- 3 files changed, 186 insertions(+), 2 deletions(-) create mode 100644 mcp/streamable_bench_test.go diff --git a/mcp/event_test.go b/mcp/event_test.go index 20808c73..ef4e080b 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -287,7 +287,7 @@ func BenchmarkMemoryEventStore(b *testing.B) { payload := make([]byte, test.datasize) start := time.Now() b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { sessionID := sessionIDs[i%len(sessionIDs)] streamID := streamIDs[i%len(sessionIDs)][i%3] store.Append(ctx, sessionID, streamID, payload) diff --git a/mcp/streamable_bench_test.go b/mcp/streamable_bench_test.go new file mode 100644 index 00000000..07f81da3 --- /dev/null +++ b/mcp/streamable_bench_test.go @@ -0,0 +1,66 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func BenchmarkStreamableServing(b *testing.B) { + // This benchmark measures how fast we can handle a single tool on a + // streamable server, including tool validation and stream management. + customSchemas := map[any]*jsonschema.Schema{ + Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, + WeatherType(""): {Type: "string", Enum: []any{Sunny, PartlyCloudy, Cloudy, Rainy, Snowy}}, + } + opts := &jsonschema.ForOptions{TypeSchemas: customSchemas} + in, err := jsonschema.For[WeatherInput](opts) + if err != nil { + b.Fatal(err) + } + out, err := jsonschema.For[WeatherOutput](opts) + if err != nil { + b.Fatal(err) + } + + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{ + Name: "weather", + InputSchema: in, + OutputSchema: out, + }, WeatherTool) + + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + session, err := mcp.NewClient(testImpl, nil).Connect(ctx, &mcp.StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + b.Fatal(err) + } + defer session.Close() + b.ResetTimer() + for range b.N { + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "weather", + Arguments: WeatherInput{ + Location: Location{Name: "somewhere"}, + Days: 7, + }, + }); err != nil { + b.Errorf("CallTool failed: %v", err) + } + } +} diff --git a/mcp/tool_example_test.go b/mcp/tool_example_test.go index 29ccd9c2..888309bc 100644 --- a/mcp/tool_example_test.go +++ b/mcp/tool_example_test.go @@ -15,7 +15,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ExampleAddTool_customTypeSchema() { +func ExampleAddTool_customMarshalling() { // Sometimes when you want to customize the input or output schema for a // tool, you need to customize the schema of a single helper type that's used // in several places. @@ -83,6 +83,124 @@ func ExampleAddTool_customTypeSchema() { // } } +type WeatherInput struct { + Location Location `json:"location" jsonschema:"user location"` + Days int `json:"days" jsonschema:"number of days to forecast"` +} + +type Location struct { + Name string `json:"name"` + Latitude *float64 `json:"latitude,omitempty"` + Longitude *float64 `json:"longitude,omitempty"` +} + +type Forecast struct { + Forecast string `json:"forecast" jsonschema:"description of the day's weather"` + Type WeatherType `json:"type" jsonschema:"type of weather"` + Rain float64 `json:"rain" jsonschema:"probability of rain, between 0 and 1"` + High float64 `json:"high" jsonschema:"high temperature"` + Low float64 `json:"low" jsonschema:"low temperature"` +} + +type WeatherType string + +const ( + Sunny WeatherType = "sun" + PartlyCloudy WeatherType = "partly_cloudy" + Cloudy WeatherType = "clouds" + Rainy WeatherType = "rain" + Snowy WeatherType = "snow" +) + +type Probability float64 + +type WeatherOutput struct { + Summary string `json:"summary" jsonschema:"a summary of the weather forecast"` + Confidence Probability `json:"confidence" jsonschema:"confidence, between 0 and 1"` + AsOf time.Time `json:"asOf" jsonschema:"the time the weather was computed"` + DailyForecast []Forecast `json:"dailyForecast" jsonschema:"the daily forecast"` + Source string `json:"source,omitempty" jsonschema:"the organization providing the weather forecast"` +} + +func WeatherTool(ctx context.Context, req *mcp.CallToolRequest, in WeatherInput) (*mcp.CallToolResult, WeatherOutput, error) { + perfectWeather := WeatherOutput{ + Summary: "perfect", + Confidence: 1.0, + AsOf: time.Now(), + } + for range in.Days { + perfectWeather.DailyForecast = append(perfectWeather.DailyForecast, Forecast{ + Forecast: "another perfect day", + Type: Sunny, + Rain: 0.0, + High: 72.0, + Low: 72.0, + }) + } + return nil, perfectWeather, nil +} + +func ExampleAddTool_complexSchema() { + // This example demonstrates a tool with a more 'realistic' input and output + // schema. We use a combination of techniques to tune our input and output + // schemas. + + // Distinguished Go types allow custom schemas to be reused during inference. + customSchemas := map[any]*jsonschema.Schema{ + Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, + WeatherType(""): {Type: "string", Enum: []any{Sunny, PartlyCloudy, Cloudy, Rainy, Snowy}}, + } + opts := &jsonschema.ForOptions{TypeSchemas: customSchemas} + in, err := jsonschema.For[WeatherInput](opts) + if err != nil { + log.Fatal(err) + } + + // Furthermore, we can tweak the inferred schema, in this case limiting + // forecasts to 0-10 days. + daysSchema := in.Properties["days"] + daysSchema.Minimum = jsonschema.Ptr(0.0) + daysSchema.Maximum = jsonschema.Ptr(10.0) + + // Output schema inference can reuse our custom schemas from input inference. + out, err := jsonschema.For[WeatherOutput](opts) + if err != nil { + log.Fatal(err) + } + + // Now add our tool to a server. Since we've customized the schemas, we need + // to override the default schema inference. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{ + Name: "weather", + InputSchema: in, + OutputSchema: out, + }, WeatherTool) + + ctx := context.Background() + session, err := connect(ctx, server) // create an in-memory connection + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // Check that the client observes the correct schemas. + for t, err := range session.Tools(ctx, nil) { + if err != nil { + log.Fatal(err) + } + // Formatting the entire schemas would be too much output. + // Just check that our customizations were effective. + fmt.Println("max days:", *t.InputSchema.Properties["days"].Maximum) + fmt.Println("max confidence:", *t.OutputSchema.Properties["confidence"].Maximum) + fmt.Println("weather types:", t.OutputSchema.Properties["dailyForecast"].Items.Properties["type"].Enum) + } + // Output: + // max days: 10 + // max confidence: 1 + // weather types: [sun partly_cloudy clouds rain snow] +} + func connect(ctx context.Context, server *mcp.Server) (*mcp.ClientSession, error) { t1, t2 := mcp.NewInMemoryTransports() if _, err := server.Connect(ctx, t1, nil); err != nil { From 140b93980c512d06a46732ed843f2beaa3bb81b8 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 18 Sep 2025 08:08:23 -0400 Subject: [PATCH 203/221] mcp: add CallToolResult.getError (#481) Provide a way for middleware to get the error from a tool call, demonstrating that we can add this functionality without making a breaking change. Leave getError unexported for now; we can export it (and setError) at any time. For #64. --- mcp/cmd_test.go | 2 +- mcp/content_nil_test.go | 4 ++- mcp/mcp_test.go | 75 ++++++++++++++++++++++++++++++++++++++--- mcp/protocol.go | 13 +++++++ mcp/protocol_test.go | 2 +- mcp/sse_test.go | 2 +- mcp/streamable_test.go | 4 +-- 7 files changed, 91 insertions(+), 11 deletions(-) diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 0df45708..cbaadcb0 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -226,7 +226,7 @@ func TestCmdTransport(t *testing.T) { &mcp.TextContent{Text: "Hi user"}, }, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) } if err := session.Close(); err != nil { diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index 70cabfd7..8cc7bdbd 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -72,7 +72,7 @@ func TestContentUnmarshalNil(t *testing.T) { } // Verify that the Content field was properly populated - if cmp.Diff(tt.want, tt.content) != "" { + if cmp.Diff(tt.want, tt.content, ctrCmpOpts...) != "" { t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content)) } }) @@ -222,3 +222,5 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { }) } } + +var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index dd542d3d..fa941bf0 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -234,7 +234,7 @@ func TestEndToEnd(t *testing.T) { &TextContent{Text: "hi user"}, }, } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { + if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } @@ -253,7 +253,7 @@ func TestEndToEnd(t *testing.T) { &TextContent{Text: errTestFailure.Error()}, }, } - if diff := cmp.Diff(wantFail, gotFail); diff != "" { + if diff := cmp.Diff(wantFail, gotFail, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } @@ -1717,7 +1717,7 @@ func TestPointerArgEquivalence(t *testing.T) { if err != nil { t.Fatal(err) } - if diff := cmp.Diff(r0, r1); diff != "" { + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) } } @@ -1733,7 +1733,7 @@ func TestPointerArgEquivalence(t *testing.T) { if err != nil { t.Fatal(err) } - if diff := cmp.Diff(r0, r1); diff != "" { + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) } }) @@ -1837,3 +1837,70 @@ func TestEmbeddedStructResponse(t *testing.T) { t.Errorf("CallTool() failed: %v", err) } } + +func TestToolErrorMiddleware(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return nil, nil, errTestFailure + }) + + var middleErr error + s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + res, err := h(ctx, method, req) + if err == nil { + if ctr, ok := res.(*CallToolResult); ok { + middleErr = ctr.getError() + } + } + return res, err + } + }) + _, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + client := NewClient(&Implementation{Name: "test-client"}, nil) + clientSession, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer clientSession.Close() + + _, err = clientSession.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "al"}, + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } + if middleErr != nil { + t.Errorf("middleware got error %v, want nil", middleErr) + } + res, err := clientSession.CallTool(ctx, &CallToolParams{ + Name: "fail", + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } + if !res.IsError { + t.Fatal("want error, got none") + } + // Clients can't see the error, because it isn't marshaled. + if err := res.getError(); err != nil { + t.Fatalf("got %v, want nil", err) + } + if middleErr != errTestFailure { + t.Errorf("middleware got err %v, want errTestFailure", middleErr) + } +} + +var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} diff --git a/mcp/protocol.go b/mcp/protocol.go index 7be8ea17..3e3c544e 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -103,6 +103,12 @@ type CallToolResult struct { // tool handler returns an error, and the error string is included as text in // the Content field. IsError bool `json:"isError,omitempty"` + + // The error passed to setError, if any. + // It is not marshaled, and therefore it is only visible on the server. + // Its only use is in server sending middleware, where it can be accessed + // with getError. + err error } // TODO(#64): consider exposing setError (and getError), by adding an error @@ -110,6 +116,13 @@ type CallToolResult struct { func (r *CallToolResult) setError(err error) { r.Content = []Content{&TextContent{Text: err.Error()}} r.IsError = true + r.err = err +} + +// getError returns the error set with setError, or nil if none. +// This function always returns nil on clients. +func (r *CallToolResult) getError() error { + return r.err } func (*CallToolResult) isResult() {} diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index 28e97518..67d021d1 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -499,7 +499,7 @@ func TestContentUnmarshal(t *testing.T) { if err := json.Unmarshal(data, out); err != nil { t.Fatal(err) } - if diff := cmp.Diff(in, out); diff != "" { + if diff := cmp.Diff(in, out, ctrCmpOpts...); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 408e92ec..32a20bf3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -70,7 +70,7 @@ func TestSSEServer(t *testing.T) { &TextContent{Text: "hi user"}, }, } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { + if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e077308c..79e9645f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -159,7 +159,7 @@ func TestStreamableTransports(t *testing.T) { want := &CallToolResult{ Content: []Content{&TextContent{Text: "hi foo"}}, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } @@ -550,8 +550,6 @@ func resp(id int64, result any, err error) *jsonrpc.Response { } } -var () - func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP From 208bfe25fd305457f4731e82519e16b93d952047 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 18 Sep 2025 08:25:24 -0400 Subject: [PATCH 204/221] oauthex: new package for oauth extensions (#429) This package will hold externally visible parts of internal/oauthex. For now, just add an alias to the ProtectedResourceMetadata struct. This lets one write MCP-compliant servers. --- README.md | 6 +++++- internal/readme/README.src.md | 6 +++++- oauthex/oauthex.go | 10 ++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 oauthex/oauthex.go diff --git a/README.md b/README.md index 2888b740..ab989028 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ software development kit (SDK) for the Model Context Protocol (MCP). ## Package documentation -The SDK consists of three importable packages: +The SDK consists of several importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) @@ -34,6 +34,10 @@ The SDK consists of three importable packages: [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) package provides some primitives for supporting oauth. +- The + [`github.com/modelcontextprotocol/go-sdk/oauthex`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/oauthex) + package provides extensions to the OAuth protocol, such as ProtectedResourceMetadata. + ## Getting started To get started creating an MCP server, create an `mcp.Server` instance, add diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 45a57f65..acd08e41 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -20,7 +20,7 @@ software development kit (SDK) for the Model Context Protocol (MCP). ## Package documentation -The SDK consists of three importable packages: +The SDK consists of several importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) @@ -33,6 +33,10 @@ The SDK consists of three importable packages: [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) package provides some primitives for supporting oauth. +- The + [`github.com/modelcontextprotocol/go-sdk/oauthex`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/oauthex) + package provides extensions to the OAuth protocol, such as ProtectedResourceMetadata. + ## Getting started To get started creating an MCP server, create an `mcp.Server` instance, add diff --git a/oauthex/oauthex.go b/oauthex/oauthex.go new file mode 100644 index 00000000..eb3a3e78 --- /dev/null +++ b/oauthex/oauthex.go @@ -0,0 +1,10 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex + +import "github.com/modelcontextprotocol/go-sdk/internal/oauthex" + +type ProtectedResourceMetadata = oauthex.ProtectedResourceMetadata From b636b1633b2c167ff4e2209ae1953a1c0843c303 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 18 Sep 2025 08:30:53 -0400 Subject: [PATCH 205/221] mcp: fix flaky ExampleLoggingTransport (#483) Reads are technically racy with writes. Sort for stability. --- docs/troubleshooting.md | 9 ++++++--- mcp/transport_example_test.go | 14 +++++++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index c0f021b6..0f990edc 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -39,12 +39,15 @@ func ExampleLoggingTransport() { if _, err := client.Connect(ctx, logTransport, nil); err != nil { log.Fatal(err) } - fmt.Println(b.String()) + // Sort for stability: reads are concurrent to writes. + for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) { + fmt.Println(line) + } + // Output: - // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} // read: {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.0.1"}}} + // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} // write: {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} - } ``` diff --git a/mcp/transport_example_test.go b/mcp/transport_example_test.go index dcf1a8ba..064ab0f2 100644 --- a/mcp/transport_example_test.go +++ b/mcp/transport_example_test.go @@ -2,6 +2,9 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// Uses strings.SplitSeq. +//go:build go1.24 + package mcp_test import ( @@ -9,6 +12,8 @@ import ( "context" "fmt" "log" + "slices" + "strings" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -29,12 +34,15 @@ func ExampleLoggingTransport() { if _, err := client.Connect(ctx, logTransport, nil); err != nil { log.Fatal(err) } - fmt.Println(b.String()) + // Sort for stability: reads are concurrent to writes. + for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) { + fmt.Println(line) + } + // Output: - // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} // read: {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.0.1"}}} + // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} // write: {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} - } // !-loggingtransport From d870f5e63813a9bd473255902def98d2d2a1b350 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:22:57 -0400 Subject: [PATCH 206/221] build(deps): bump github.com/golang-jwt/jwt/v5 from 5.2.1 to 5.2.2 in the go_modules group across 1 directory (#476) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps the go_modules group with 1 update in the / directory: [github.com/golang-jwt/jwt/v5](https://github.com/golang-jwt/jwt). Updates `github.com/golang-jwt/jwt/v5` from 5.2.1 to 5.2.2
Release notes

Sourced from github.com/golang-jwt/jwt/v5's releases.

v5.2.2

What's Changed

New Contributors

Full Changelog: https://github.com/golang-jwt/jwt/compare/v5.2.1...v5.2.2

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/golang-jwt/jwt/v5&package-manager=go_modules&previous-version=5.2.1&new-version=5.2.2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore major version` will close this group update PR and stop Dependabot creating any more for the specific dependency's major version (unless you unignore this specific dependency's major version or upgrade to it yourself) - `@dependabot ignore minor version` will close this group update PR and stop Dependabot creating any more for the specific dependency's minor version (unless you unignore this specific dependency's minor version or upgrade to it yourself) - `@dependabot ignore ` will close this group update PR and stop Dependabot creating any more for the specific dependency (unless you unignore this specific dependency or upgrade to it yourself) - `@dependabot unignore ` will remove all of the ignore conditions of the specified dependency - `@dependabot unignore ` will remove the ignore condition of the specified dependency and ignore conditions You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/modelcontextprotocol/go-sdk/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 06896252..f5c578cf 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/modelcontextprotocol/go-sdk go 1.23.0 require ( - github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/go-cmp v0.7.0 github.com/google/jsonschema-go v0.2.3 github.com/yosida95/uritemplate/v3 v3.0.2 diff --git a/go.sum b/go.sum index 7ccdc6a8..6a392638 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,7 @@ -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo= -github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= From f5cee0aa32d6aa906530e53361494aa21ffcd636 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 18 Sep 2025 11:22:40 -0400 Subject: [PATCH 207/221] docs: document stateless mode and custom transports (#485) --- docs/protocol.md | 32 ++++++++++++++++++++++++++------ internal/docs/protocol.src.md | 32 ++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/docs/protocol.md b/docs/protocol.md index dbc4c1cb..75c5d127 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -182,15 +182,35 @@ the server using the streamable transport protocol. #### Stateless Mode - - -#### Sessionless mode - - +The streamable server supports a _stateless mode_ by setting +[`StreamableHTTPOptions.Stateless`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableHTTPOptions.Stateless), +which is where the server does not perform any validation of the session id, +and uses a temporary session to handle requests. In this mode, it is impossible +for the server to make client requests, as there is no way for the client's +response to reach the session. + +However, it is still possible for the server to access the `ServerSession.ID` +to see the logical session + +> [!WARNING] +> Stateless mode is not directly discussed in the spec, and is still being +> defined. See modelcontextprotocol/modelcontextprotocol#1364, +> modelcontextprotocol/modelcontextprotocol#1372, or +> modelcontextprotocol/modelcontextprotocol#11442 for potential refinements. + +_See [examples/server/distributed](../examples/server/distributed/main.go) for +an example using statless mode to implement a server distributed across +multiple processes._ ### Custom transports - +The SDK supports [custom +transports](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#custom-transports) +by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface: a logical bidirectional stream of JSON-RPC messages. + +_Full example: [examples/server/custom-transport](../examples/server/custom-transport/main.go)._ ### Concurrency diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 79b72418..2bb954bf 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -114,15 +114,35 @@ the server using the streamable transport protocol. #### Stateless Mode - - -#### Sessionless mode - - +The streamable server supports a _stateless mode_ by setting +[`StreamableHTTPOptions.Stateless`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableHTTPOptions.Stateless), +which is where the server does not perform any validation of the session id, +and uses a temporary session to handle requests. In this mode, it is impossible +for the server to make client requests, as there is no way for the client's +response to reach the session. + +However, it is still possible for the server to access the `ServerSession.ID` +to see the logical session + +> [!WARNING] +> Stateless mode is not directly discussed in the spec, and is still being +> defined. See modelcontextprotocol/modelcontextprotocol#1364, +> modelcontextprotocol/modelcontextprotocol#1372, or +> modelcontextprotocol/modelcontextprotocol#11442 for potential refinements. + +_See [examples/server/distributed](../examples/server/distributed/main.go) for +an example using statless mode to implement a server distributed across +multiple processes._ ### Custom transports - +The SDK supports [custom +transports](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#custom-transports) +by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface: a logical bidirectional stream of JSON-RPC messages. + +_Full example: [examples/server/custom-transport](../examples/server/custom-transport/main.go)._ ### Concurrency From 1cfd489d722c9b9ae45d54141c5e83d0772ba816 Mon Sep 17 00:00:00 2001 From: Frenchie Date: Fri, 19 Sep 2025 01:25:19 +1000 Subject: [PATCH 208/221] chore(deps): add dependabot and SHA-pin workflow deps (#456) Introduces Dependabot & SHA-pins workflow deps, as per the [Discord thread](https://discord.com/channels/1358869848138059966/1399986970851282954/1415896798144368762). A note on the commit message: I used convention commit format out of habit, but can rewrite the it to adhere to the [styleguide](https://go.dev/wiki/CommitMessage) if critical. --- .github/workflows/docs-check.yml | 39 ++++++++++++++++++++++++++++++ .github/workflows/readme-check.yml | 39 ------------------------------ .github/workflows/test.yml | 14 +++++------ 3 files changed, 46 insertions(+), 46 deletions(-) create mode 100644 .github/workflows/docs-check.yml delete mode 100644 .github/workflows/readme-check.yml diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml new file mode 100644 index 00000000..cb6b841b --- /dev/null +++ b/.github/workflows/docs-check.yml @@ -0,0 +1,39 @@ +name: Docs Check +on: + workflow_dispatch: + pull_request: + paths: + - 'internal/readme/**' + - 'README.md' + - 'internal/docs/**' + - 'docs/**' + +permissions: + contents: read + +jobs: + docs-check: + runs-on: ubuntu-latest + steps: + - name: Set up Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + - name: Check out code + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 + - name: Check docs are up-to-date + run: | + go generate ./... + if [ -n "$(git status --porcelain)" ]; then + echo "ERROR: docs are not up-to-date!" + echo "" + echo "The docs differ from what would be generated by `go generate ./...`." + echo "Please update internal/**/*.src.md instead of directly editing README.md or docs/ files," + echo "then run `go generate ./...` to regenerate docs." + echo "" + echo "Changes:" + git status --porcelain + echo "" + echo "Diff:" + git diff + exit 1 + fi + echo "Docs are up-to-date." diff --git a/.github/workflows/readme-check.yml b/.github/workflows/readme-check.yml deleted file mode 100644 index 8709be11..00000000 --- a/.github/workflows/readme-check.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: README Check -on: - workflow_dispatch: - pull_request: - paths: - - 'internal/readme/**' - - 'README.md' - - 'internal/docs/**' - - 'docs/**' - -permissions: - contents: read - -jobs: - readme-check: - runs-on: ubuntu-latest - steps: - - name: Set up Go - uses: actions/setup-go@v5 - - name: Check out code - uses: actions/checkout@v4 - - name: Check docs is up-to-date - run: | - go generate ./... - if [ -n "$(git status --porcelain)" ]; then - echo "ERROR: docs are not up-to-date!" - echo "" - echo "The docs differ from what would be generated by `go generate ./...`." - echo "Please update internal/**/*.src.md instead of directly editing README.md or docs/ files," - echo "then run `go generate ./...` to regenerate docs." - echo "" - echo "Changes:" - git status --porcelain - echo "" - echo "Diff:" - git diff - exit 1 - fi - echo "Docs are up-to-date." diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9ad50c1d..03662463 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,9 +14,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 with: go-version: "^1.23" - name: Check formatting @@ -31,7 +31,7 @@ jobs: - name: Run Go vet run: go vet ./... - name: Run staticcheck - uses: dominikh/staticcheck-action@v1 + uses: dominikh/staticcheck-action@024238d2898c874f26d723e7d0ff4308c35589a2 # v1 with: version: "latest" @@ -42,9 +42,9 @@ jobs: go: ["1.23", "1.24", "1.25"] steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 with: go-version: ${{ matrix.go }} - name: Test @@ -54,9 +54,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 with: go-version: "1.24" - name: Test with -race From 353d46ff86f842a0d322778384b3cea6bc802edc Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 18 Sep 2025 12:01:07 -0400 Subject: [PATCH 209/221] mcp: remove the StreamID type (#490) The streamID type did not carry its weight. Remove it. Fixes #484. --- mcp/event.go | 26 ++++++++++++++------------ mcp/event_test.go | 10 +++++----- mcp/streamable.go | 30 +++++++++++++----------------- mcp/streamable_test.go | 2 +- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index d309c4e0..bd78cdee 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -156,11 +156,13 @@ type EventStore interface { // Open prepares the event store for a given stream. It ensures that the // underlying data structure for the stream is initialized, making it // ready to store event streams. - Open(_ context.Context, sessionID string, streamID StreamID) error + // + // streamIDs must be globally unique. + Open(_ context.Context, sessionID, streamID string) error // Append appends data for an outgoing event to given stream, which is part of the // given session. - Append(_ context.Context, sessionID string, _ StreamID, data []byte) error + Append(_ context.Context, sessionID, streamID string, data []byte) error // After returns an iterator over the data for the given session and stream, beginning // just after the given index. @@ -168,7 +170,7 @@ type EventStore interface { // After's iterator must return an error immediately if any data after index was // dropped; it must not return partial results. // The stream must have been opened previously (see [EventStore.Open]). - After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] + After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] // SessionClosed informs the store that the given session is finished, along // with all of its streams. @@ -217,9 +219,9 @@ func (dl *dataList) removeFirst() int { // A MemoryEventStore is an [EventStore] backed by memory. type MemoryEventStore struct { mu sync.Mutex - maxBytes int // max total size of all data - nBytes int // current total size of all data - store map[string]map[StreamID]*dataList // session ID -> stream ID -> *dataList + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[string]*dataList // session ID -> stream ID -> *dataList } // MemoryEventStoreOptions are options for a [MemoryEventStore]. @@ -258,13 +260,13 @@ const defaultMaxBytes = 10 << 20 // 10 MiB func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { return &MemoryEventStore{ maxBytes: defaultMaxBytes, - store: make(map[string]map[StreamID]*dataList), + store: make(map[string]map[string]*dataList), } } // Open implements [EventStore.Open]. It ensures that the underlying data // structures for the given session are initialized and ready for use. -func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error { +func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error { s.mu.Lock() defer s.mu.Unlock() s.init(sessionID, streamID) @@ -275,10 +277,10 @@ func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID St // given sessionID and streamID exists, creating it if necessary. It returns the // dataList associated with the specified IDs. // Requires s.mu. -func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList { +func (s *MemoryEventStore) init(sessionID, streamID string) *dataList { streamMap, ok := s.store[sessionID] if !ok { - streamMap = make(map[StreamID]*dataList) + streamMap = make(map[string]*dataList) s.store[sessionID] = streamMap } dl, ok := streamMap[streamID] @@ -290,7 +292,7 @@ func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList { } // Append implements [EventStore.Append] by recording data in memory. -func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { +func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error { s.mu.Lock() defer s.mu.Unlock() dl := s.init(sessionID, streamID) @@ -307,7 +309,7 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID var ErrEventsPurged = errors.New("data purged") // After implements [EventStore.After]. -func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] { +func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { // Return the data items to yield. // We must copy, because dataList.removeFirst nils out slice elements. copyData := func() ([][]byte, error) { diff --git a/mcp/event_test.go b/mcp/event_test.go index ef4e080b..dacf30e8 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -105,8 +105,8 @@ func TestScanEvents(t *testing.T) { func TestMemoryEventStoreState(t *testing.T) { ctx := context.Background() - appendEvent := func(s *MemoryEventStore, sess string, str StreamID, data string) { - if err := s.Append(ctx, sess, str, []byte(data)); err != nil { + appendEvent := func(s *MemoryEventStore, sess, stream string, data string) { + if err := s.Append(ctx, sess, stream, []byte(data)); err != nil { t.Fatal(err) } } @@ -218,7 +218,7 @@ func TestMemoryEventStoreAfter(t *testing.T) { for _, tt := range []struct { sessionID string - streamID StreamID + streamID string index int want []string wantErr string // if non-empty, error should contain this string @@ -277,11 +277,11 @@ func BenchmarkMemoryEventStore(b *testing.B) { store.SetMaxBytes(test.limit) ctx := context.Background() sessionIDs := make([]string, test.sessions) - streamIDs := make([][3]StreamID, test.sessions) + streamIDs := make([][3]string, test.sessions) for i := range sessionIDs { sessionIDs[i] = fmt.Sprint(i) for j := range 3 { - streamIDs[i][j] = StreamID(randText()) + streamIDs[i][j] = randText() } } payload := make([]byte, test.datasize) diff --git a/mcp/streamable.go b/mcp/streamable.go index 8ac6f59a..67f187ea 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -396,8 +396,8 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er jsonResponse: t.jsonResponse, incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), - streams: make(map[StreamID]*stream), - requestStreams: make(map[jsonrpc.ID]StreamID), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), } if t.connection.eventStore == nil { t.connection.eventStore = NewMemoryEventStore(nil) @@ -442,14 +442,14 @@ type streamableServerConn struct { // bound. If we deleted a stream when the response is sent, we would lose the ability // to replay if there was a cut just before the response was transmitted. // Perhaps we could have a TTL for streams that starts just after the response. - streams map[StreamID]*stream + streams map[string]*stream // requestStreams maps incoming requests to their logical stream ID. // // Lifecycle: requestStreams persist for the duration of the session. // // TODO: clean up once requests are handled. See the TODO for streams above. - requestStreams map[jsonrpc.ID]StreamID + requestStreams map[jsonrpc.ID]string } func (c *streamableServerConn) SessionID() string { @@ -466,7 +466,7 @@ func (c *streamableServerConn) SessionID() string { type stream struct { // id is the logical ID for the stream, unique within a session. // an empty string is used for messages that don't correlate with an incoming request. - id StreamID + id string // If isInitialize is set, the stream is in response to an initialize request, // and therefore should include the session ID header. @@ -500,7 +500,7 @@ type stream struct { requests map[jsonrpc.ID]struct{} } -func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) { +func (c *streamableServerConn) newStream(ctx context.Context, id string, isInitialize, jsonResponse bool) (*stream, error) { if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { return nil, err } @@ -517,10 +517,6 @@ func signalChanPtr() *chan struct{} { return &c } -// A StreamID identifies a stream of SSE events. It is globally unique. -// [ServerSession]. -type StreamID string - // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in // the course of handling incoming requests are correlated with the incoming @@ -569,7 +565,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // It returns an HTTP status code and error message. func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id := StreamID("") + id := "" // By default, we haven't seen a last index. Since indices start at 0, we represent // that by -1. This is incremented just before each event is written, in streamResponse // around L407. @@ -669,7 +665,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse) + stream, err = c.newStream(req.Context(), randText(), isInitialize, c.jsonResponse) if err != nil { http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return @@ -860,7 +856,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per // streamID and message index idx. // // See also [parseEventID]. -func formatEventID(sid StreamID, idx int) string { +func formatEventID(sid string, idx int) string { return fmt.Sprintf("%s_%d", sid, idx) } @@ -868,17 +864,17 @@ func formatEventID(sid StreamID, idx int) string { // index. // // See also [formatEventID]. -func parseEventID(eventID string) (sid StreamID, idx int, ok bool) { +func parseEventID(eventID string) (streamID string, idx int, ok bool) { parts := strings.Split(eventID, "_") if len(parts) != 2 { return "", 0, false } - stream := StreamID(parts[0]) + streamID = parts[0] idx, err := strconv.Atoi(parts[1]) if err != nil || idx < 0 { return "", 0, false } - return StreamID(stream), idx, true + return streamID, idx, true } // Read implements the [Connection] interface. @@ -922,7 +918,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e // // For messages sent outside of a request context, this is the default // connection "". - var forStream StreamID + var forStream string if forRequest.IsValid() { c.mu.Lock() forStream = c.requestStreams[forRequest] diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 79e9645f..e24478be 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1139,7 +1139,7 @@ func mustMarshal(v any) json.RawMessage { func TestEventID(t *testing.T) { tests := []struct { - sid StreamID + sid string idx int }{ {"0", 0}, From 02f0b2528642b41fa9be80f7f4773cbd3eb28994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Friedrich=20Gro=C3=9Fe?= Date: Thu, 18 Sep 2025 18:21:02 +0200 Subject: [PATCH 210/221] Move GetSessionID closure into ServerOptions (#488) Fixes: #478. --- mcp/server.go | 15 +++++++++++++++ mcp/streamable.go | 15 +-------------- mcp/streamable_test.go | 27 ++++++++++++++++++--------- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index fd38a89a..27de09a3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -83,6 +83,16 @@ type ServerOptions struct { // If true, advertises the tools capability during initialization, // even if no tools have been registered. HasTools bool + + // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. + // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. + GetSessionID func() string } // NewServer creates a new MCP server. The resulting server has no features: @@ -114,6 +124,11 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { panic("UnsubscribeHandler requires SubscribeHandler") } + + if opts.GetSessionID == nil { + opts.GetSessionID = randText + } + return &Server{ impl: impl, opts: opts, diff --git a/mcp/streamable.go b/mcp/streamable.go index 67f187ea..4ab343b2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -50,16 +50,6 @@ type StreamableHTTPHandler struct { // StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { - // GetSessionID provides the next session ID to use for an incoming request. - // If nil, a default randomly generated ID will be used. - // - // Session IDs should be globally unique across the scope of the server, - // which may span multiple processes in the case of distributed servers. - // - // As a special case, if GetSessionID returns the empty string, the - // Mcp-Session-Id header will not be set. - GetSessionID func() string - // Stateless controls whether the session is 'stateless'. // // A stateless server does not validate the Mcp-Session-Id header, and uses a @@ -92,9 +82,6 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea if opts != nil { h.opts = *opts } - if h.opts.GetSessionID == nil { - h.opts.GetSessionID = randText - } return h } @@ -233,7 +220,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque if sessionID == "" { // In stateless mode, sessionID may be nonempty even if there's no // existing transport. - sessionID = h.opts.GetSessionID() + sessionID = server.opts.GetSessionID() } transport = &StreamableServerTransport{ SessionID: sessionID, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e24478be..3b967f8f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -409,15 +409,14 @@ func testClientReplay(t *testing.T, test clientReplayTest) { } func TestServerTransportCleanup(t *testing.T) { - server := NewServer(testImpl, &ServerOptions{KeepAlive: 10 * time.Millisecond}) - nClient := 3 var mu sync.Mutex var id int = -1 // session id starting from "0", "1", "2"... chans := make(map[string]chan struct{}, nClient) - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + server := NewServer(testImpl, &ServerOptions{ + KeepAlive: 10 * time.Millisecond, GetSessionID: func() string { mu.Lock() defer mu.Unlock() @@ -430,6 +429,7 @@ func TestServerTransportCleanup(t *testing.T) { }, }) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) handler.onTransportDeletion = func(sessionID string) { chans[sessionID] <- struct{}{} } @@ -1199,8 +1199,6 @@ func TestStreamableStateless(t *testing.T) { } return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil } - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) requests := []streamableRequest{ { @@ -1263,9 +1261,15 @@ func TestStreamableStateless(t *testing.T) { } } - sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ - GetSessionID: func() string { return "" }, - Stateless: true, + sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { + // Return a stateless server which never assigns a session ID. + server := NewServer(testImpl, &ServerOptions{ + GetSessionID: func() string { return "" }, + }) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + return server + }, &StreamableHTTPOptions{ + Stateless: true, }) // First, test the "sessionless" stateless mode, where there is no session ID. @@ -1279,7 +1283,12 @@ func TestStreamableStateless(t *testing.T) { // This can be used by tools to look up application state preserved across // subsequent requests. requests[0].wantSessionID = true // now expect a session ID for initialize - statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { + // Return a server with default options which should assign a random session ID. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + return server + }, &StreamableHTTPOptions{ Stateless: true, }) t.Run("stateless", func(t *testing.T) { From fed510663181dc66464f9f324ac694332dbdd8ef Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 18 Sep 2025 12:27:02 -0400 Subject: [PATCH 211/221] mcp: add prompt feature doc (#487) --- docs/client.md | 2 +- docs/server.md | 88 ++++++++++++++++++++++++++++++++++++- internal/docs/client.src.md | 2 +- internal/docs/server.src.md | 22 +++++++++- mcp/server_example_test.go | 83 ++++++++++++++++++++++++++++++++++ 5 files changed, 193 insertions(+), 4 deletions(-) create mode 100644 mcp/server_example_test.go diff --git a/docs/client.md b/docs/client.md index 13c12d57..cbc4db8f 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ The SDK supports this as follows: **Client-side**: The SDK client always has the `roots.listChanged` capability. To add roots to a client, use the -[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddRoots) +[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.AddRoots) and [`Client.RemoveRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.RemoveRoots) methods. If any servers are already [connected](protocol.md#lifecycle) to the diff --git a/docs/server.md b/docs/server.md index 5b17f2e5..e6af5d9c 100644 --- a/docs/server.md +++ b/docs/server.md @@ -11,7 +11,93 @@ ## Prompts - +**Server-side**: +MCP servers can provide LLM prompt templates (called simply _prompts_) to clients. +Every prompt has a required name which identifies it, and a set of named arguments, which are strings. +Construct a prompt with a name and descriptions of its arguments. +Associated with each prompt is a handler that expands the template given values for its arguments. +Use [`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) +to add a prompt along with its handler. +If `AddPrompt` is called before a server is connected, the server will have the `prompts` capability. +If all prompts are to be added after connection, set [`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +to advertise the capability. + +**Client-side**: +To list the server's prompts, call +Call [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) to get an iterator. +If needed, you can use the lower-level +[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) to list the server's prompts. +Call [`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) to retrieve a prompt by name, providing +arguments for expansion. +Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) to be notified of changes in the list of prompts. + +```go +func Example_prompts() { + ctx := context.Background() + + promptHandler := func(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Hi prompt", + Messages: []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}, + }, + }, + }, nil + } + + // Create a server with a single prompt. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + prompt := &mcp.Prompt{ + Name: "greet", + Arguments: []*mcp.PromptArgument{ + { + Name: "name", + Description: "the name of the person to greet", + Required: true, + }, + }, + } + s.AddPrompt(prompt, promptHandler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + + // List the prompts. + for p, err := range cs.Prompts(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(p.Name) + } + + // Get the prompt. + res, err := cs.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: "greet", + Arguments: map[string]string{"name": "Pat"}, + }) + if err != nil { + log.Fatal(err) + } + for _, msg := range res.Messages { + fmt.Println(msg.Role, msg.Content.(*mcp.TextContent).Text) + } + // Output: + // greet + // user Say hi to Pat +} +``` ## Resources diff --git a/internal/docs/client.src.md b/internal/docs/client.src.md index f342719e..fc37d454 100644 --- a/internal/docs/client.src.md +++ b/internal/docs/client.src.md @@ -10,7 +10,7 @@ The SDK supports this as follows: **Client-side**: The SDK client always has the `roots.listChanged` capability. To add roots to a client, use the -[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddRoots) +[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.AddRoots) and [`Client.RemoveRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.RemoveRoots) methods. If any servers are already [connected](protocol.md#lifecycle) to the diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md index a131bcd3..07ed6cb9 100644 --- a/internal/docs/server.src.md +++ b/internal/docs/server.src.md @@ -4,7 +4,27 @@ ## Prompts - +**Server-side**: +MCP servers can provide LLM prompt templates (called simply _prompts_) to clients. +Every prompt has a required name which identifies it, and a set of named arguments, which are strings. +Construct a prompt with a name and descriptions of its arguments. +Associated with each prompt is a handler that expands the template given values for its arguments. +Use [`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) +to add a prompt along with its handler. +If `AddPrompt` is called before a server is connected, the server will have the `prompts` capability. +If all prompts are to be added after connection, set [`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +to advertise the capability. + +**Client-side**: +To list the server's prompts, call +Call [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) to get an iterator. +If needed, you can use the lower-level +[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) to list the server's prompts. +Call [`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) to retrieve a prompt by name, providing +arguments for expansion. +Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) to be notified of changes in the list of prompts. + +%include ../../mcp/server_example_test.go prompts - ## Resources diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go new file mode 100644 index 00000000..d6b4d7e4 --- /dev/null +++ b/mcp/server_example_test.go @@ -0,0 +1,83 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+prompts + +func Example_prompts() { + ctx := context.Background() + + promptHandler := func(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Hi prompt", + Messages: []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}, + }, + }, + }, nil + } + + // Create a server with a single prompt. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + prompt := &mcp.Prompt{ + Name: "greet", + Arguments: []*mcp.PromptArgument{ + { + Name: "name", + Description: "the name of the person to greet", + Required: true, + }, + }, + } + s.AddPrompt(prompt, promptHandler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + + // List the prompts. + for p, err := range cs.Prompts(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(p.Name) + } + + // Get the prompt. + res, err := cs.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: "greet", + Arguments: map[string]string{"name": "Pat"}, + }) + if err != nil { + log.Fatal(err) + } + for _, msg := range res.Messages { + fmt.Println(msg.Role, msg.Content.(*mcp.TextContent).Text) + } + // Output: + // greet + // user Say hi to Pat +} + +// !-prompts From 49d45a809707cbf19808614e5208012bed849da7 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 18 Sep 2025 13:31:54 -0400 Subject: [PATCH 212/221] docs: document pagination (#491) For #442 --- docs/server.md | 26 +++++++++++++++++++++++++- internal/docs/server.src.md | 26 +++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/docs/server.md b/docs/server.md index e6af5d9c..4d6192ff 100644 --- a/docs/server.md +++ b/docs/server.md @@ -121,4 +121,28 @@ func Example_prompts() { ### Pagination - +Server-side feature lists may be +[paginated](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/pagination), +using cursors. The SDK supports this by default. + +**Client-side**: The `ClientSession` provides methods returning +[iterators](https://go.dev/blog/range-functions) for each feature type. +These iterators are an `iter.Seq2[Feature, error]`, where the error value +indicates whether page retrieval failed. + +- [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) + iterates prompts. +- [`ClientSession.Resource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resource) + iterates resources. +- [`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) + iterates resource templates. +- [`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) + iterates tools. + +The `ClientSession` also exposes `ListXXX` methods for fine-grained control +over pagination. + +**Server-side**: pagination is on by default, so in general nothing is required +server-side. However, you may use +[`ServerOptions.PageSize`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.PageSize) +to customize the page size. diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md index 07ed6cb9..32f80e9d 100644 --- a/internal/docs/server.src.md +++ b/internal/docs/server.src.md @@ -48,4 +48,28 @@ Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/mod ### Pagination - +Server-side feature lists may be +[paginated](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/pagination), +using cursors. The SDK supports this by default. + +**Client-side**: The `ClientSession` provides methods returning +[iterators](https://go.dev/blog/range-functions) for each feature type. +These iterators are an `iter.Seq2[Feature, error]`, where the error value +indicates whether page retrieval failed. + +- [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) + iterates prompts. +- [`ClientSession.Resource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resource) + iterates resources. +- [`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) + iterates resource templates. +- [`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) + iterates tools. + +The `ClientSession` also exposes `ListXXX` methods for fine-grained control +over pagination. + +**Server-side**: pagination is on by default, so in general nothing is required +server-side. However, you may use +[`ServerOptions.PageSize`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.PageSize) +to customize the page size. From 1696b595c647035e01081254a5a44896b03c0cb4 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 18 Sep 2025 14:24:45 -0400 Subject: [PATCH 213/221] mcp: fix several goroutine leaks in tests (#494) Fix some (but not all) goroutine leaks in tests. Others were nontrivial, because they relate to streamable http. For #489. --- mcp/client_example_test.go | 5 +++- mcp/mcp_test.go | 49 +++++++++++++++++++++-------------- mcp/server_example_test.go | 1 + mcp/transport_example_test.go | 9 +++++-- mcp/transport_test.go | 1 + 5 files changed, 43 insertions(+), 22 deletions(-) diff --git a/mcp/client_example_test.go b/mcp/client_example_test.go index 3c3c3837..bba3da44 100644 --- a/mcp/client_example_test.go +++ b/mcp/client_example_test.go @@ -45,9 +45,12 @@ func Example_roots() { if _, err := s.Connect(ctx, t1, nil); err != nil { log.Fatal(err) } - if _, err := c.Connect(ctx, t2, nil); err != nil { + + clientSession, err := c.Connect(ctx, t2, nil) + if err != nil { log.Fatal(err) } + defer clientSession.Close() // ...and add a root. The server is notified about the change. c.AddRoots(&mcp.Root{URI: "file://b"}) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index fa941bf0..9cc105cd 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -592,7 +592,10 @@ func errorCode(err error) int64 { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) { +// +// The returned func cleans up by closing the client and waiting for the server +// to shut down. +func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession, func()) { return basicClientServerConnection(t, nil, nil, config) } @@ -604,7 +607,10 @@ func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *Serve // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) { +// +// The returned func cleans up by closing the client and waiting for the server +// to shut down. +func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession, func()) { t.Helper() ctx := context.Background() @@ -628,14 +634,17 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c if err != nil { t.Fatal(err) } - return cs, ss + return cs, ss, func() { + cs.Close() + ss.Wait() + } } func TestServerClosing(t *testing.T) { - cs, ss := basicConnection(t, func(s *Server) { + cs, ss, cleanup := basicConnection(t, func(s *Server) { AddTool(s, greetTool(), sayHi) }) - defer cs.Close() + defer cleanup() ctx := context.Background() var wg sync.WaitGroup @@ -715,10 +724,10 @@ func TestCancellation(t *testing.T) { } return nil, nil, nil } - cs, _ := basicConnection(t, func(s *Server) { + cs, _, cleanup := basicConnection(t, func(s *Server) { AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowTool) }) - defer cs.Close() + defer cleanup() ctx, cancel := context.WithCancel(context.Background()) go cs.CallTool(ctx, &CallToolParams{Name: "slow"}) @@ -741,13 +750,10 @@ func TestMiddleware(t *testing.T) { t.Fatal(err) } // Wait for the server to exit after the client closes its connection. - var clientWG sync.WaitGroup - clientWG.Add(1) - go func() { + defer func() { if err := ss.Wait(); err != nil { t.Errorf("server failed: %v", err) } - clientWG.Done() }() var sbuf, cbuf bytes.Buffer @@ -767,6 +773,8 @@ func TestMiddleware(t *testing.T) { if err != nil { t.Fatal(err) } + defer cs.Close() + if _, err := cs.ListTools(ctx, nil); err != nil { t.Fatal(err) } @@ -1511,7 +1519,7 @@ func TestKeepAliveFailure(t *testing.T) { func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Adding the same tool pointer twice should not panic and should not // produce duplicates in the server's tool list. - cs, _ := basicConnection(t, func(s *Server) { + cs, _, cleanup := basicConnection(t, func(s *Server) { // Use two distinct Tool instances with the same name but different // descriptions to ensure the second replaces the first // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors @@ -1520,7 +1528,7 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { s.AddTool(t1, nopHandler) s.AddTool(t2, nopHandler) }) - defer cs.Close() + defer cleanup() ctx := context.Background() res, err := cs.ListTools(ctx, nil) @@ -1568,7 +1576,7 @@ func TestSynchronousNotifications(t *testing.T) { }, } server := NewServer(testImpl, serverOpts) - cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { + cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) { AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { if !rootsChanged.Load() { return nil, nil, fmt.Errorf("didn't get root change notification") @@ -1576,6 +1584,7 @@ func TestSynchronousNotifications(t *testing.T) { return new(CallToolResult), nil, nil }) }) + defer cleanup() t.Run("from client", func(t *testing.T) { client.AddRoots(&Root{Name: "myroot", URI: "file://foo"}) @@ -1617,7 +1626,7 @@ func TestNoDistributedDeadlock(t *testing.T) { }, } client := NewClient(testImpl, clientOpts) - cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { + cs, _, cleanup := basicClientServerConnection(t, client, nil, func(s *Server) { AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { req.Session.CreateMessage(ctx, new(CreateMessageParams)) return new(CallToolResult), nil, nil @@ -1627,7 +1636,7 @@ func TestNoDistributedDeadlock(t *testing.T) { return new(CallToolResult), nil, nil }) }) - defer cs.Close() + defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -1651,7 +1660,7 @@ func TestPointerArgEquivalence(t *testing.T) { type output struct { Out string } - cs, _ := basicConnection(t, func(s *Server) { + cs, _, cleanup := basicConnection(t, func(s *Server) { // Add two equivalent tools, one of which operates in the 'pointer' realm, // the other of which does not. // @@ -1686,7 +1695,7 @@ func TestPointerArgEquivalence(t *testing.T) { } }) }) - defer cs.Close() + defer cleanup() ctx := context.Background() tools, err := cs.ListTools(ctx, nil) @@ -1758,7 +1767,9 @@ func TestComplete(t *testing.T) { }, } server := NewServer(testImpl, serverOpts) - cs, _ := basicClientServerConnection(t, nil, server, func(s *Server) {}) + cs, _, cleanup := basicClientServerConnection(t, nil, server, func(s *Server) {}) + defer cleanup() + result, err := cs.Complete(context.Background(), &CompleteParams{ Argument: CompleteParamsArgument{ Name: "language", diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index d6b4d7e4..16f19e20 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -55,6 +55,7 @@ func Example_prompts() { if err != nil { log.Fatal(err) } + defer cs.Close() // List the prompts. for p, err := range cs.Prompts(ctx, nil) { diff --git a/mcp/transport_example_test.go b/mcp/transport_example_test.go index 064ab0f2..ab54a422 100644 --- a/mcp/transport_example_test.go +++ b/mcp/transport_example_test.go @@ -24,16 +24,21 @@ func ExampleLoggingTransport() { ctx := context.Background() t1, t2 := mcp.NewInMemoryTransports() server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) - if _, err := server.Connect(ctx, t1, nil); err != nil { + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { log.Fatal(err) } + defer serverSession.Wait() client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) var b bytes.Buffer logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b} - if _, err := client.Connect(ctx, logTransport, nil); err != nil { + clientSession, err := client.Connect(ctx, logTransport, nil) + if err != nil { log.Fatal(err) } + defer clientSession.Close() + // Sort for stability: reads are concurrent to writes. for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) { fmt.Println(line) diff --git a/mcp/transport_test.go b/mcp/transport_test.go index d40ce10f..10804a87 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -25,6 +25,7 @@ func TestBatchFraming(t *testing.T) { r, w := io.Pipe() tport := newIOConn(rwc{r, w}) tport.outgoingBatch = make([]jsonrpc.Message, 0, 2) + defer tport.Close() // Read the two messages into a channel, for easy testing later. read := make(chan jsonrpc.Message) From ff81f2f35a5873adcabf8b8aafda706c038a0534 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 18 Sep 2025 14:25:35 -0400 Subject: [PATCH 214/221] docs: add auth and security sections (#492) For #442. --- docs/protocol.md | 65 +++++++++++++++++++++++++++++++++-- internal/docs/protocol.src.md | 60 ++++++++++++++++++++++++++++++-- 2 files changed, 121 insertions(+), 4 deletions(-) diff --git a/docs/protocol.md b/docs/protocol.md index 75c5d127..e3b9e8c2 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -8,7 +8,12 @@ 1. [Custom transports](#custom-transports) 1. [Concurrency](#concurrency) 1. [Authorization](#authorization) + 1. [Server](#server) + 1. [Client](#client) 1. [Security](#security) + 1. [Confused Deputy](#confused-deputy) + 1. [Token Passthrough](#token-passthrough) + 1. [Session Hijacking](#session-hijacking) 1. [Utilities](#utilities) 1. [Cancellation](#cancellation) 1. [Ping](#ping) @@ -232,11 +237,67 @@ for more background. ## Authorization - +### Server + +To write an MCP server that performs authorization, +use [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). +This function is middleware that wraps an HTTP handler, such as the one returned +by [`NewStreamableHTTPHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewStreamableHTTPHandler), to provide support for verifying bearer tokens. +The middleware function checks every request for an Authorization header with a bearer token, +and invokes the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) + passed to `RequireBearerToken` to parse the token and perform validation. +The middleware function checks expiration and scopes (if they are provided in +[`RequireBearerTokenOptions.Scopes`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.Scopes)), so the +`TokenVerifer` doesn't have to. +If [`RequireBearerTokenOptions.ResourceMetadataURL`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.ResourceMetadataURL) is set and verification fails, +the middleware function sets the WWW-Authenticate header as required by the [Protected Resource +Metadata spec](https://datatracker.ietf.org/doc/html/rfc9728). + +The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. + +### Client + +Client-side OAuth is implemented by setting +[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) +Additional support is forthcoming; see #493. ## Security - +Here we discuss the mitigations described under +the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices) section, and how we handle them. + +### Confused Deputy + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, +happens on the MCP client. At present we don't provide client-side OAuth support. + + +### Token Passthrough + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure +of tokens and is the responsibility of the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) +provided to +[`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). + +### Session Hijacking + +The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows: + +- _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken) +middleware function will verify all HTTP requests that it receives. It is the +user's responsibility to wrap that function around all handlers in their server. + +- _Secure session IDs_. This SDK generates cryptographically secure session IDs by default. +If you create your own with +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID), it is your responsibility to ensure they are secure. +If you are using Go 1.24 or above, +we recommend using [`crypto/rand.Text`](https://pkg.go.dev/crypto/rand#Text) + +- _Binding session IDs to user information_. This is an application requirement, out of scope +for the SDK. You can create your own session IDs by setting +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID). ## Utilities diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 2bb954bf..0abd5aa8 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -164,11 +164,67 @@ for more background. ## Authorization - +### Server + +To write an MCP server that performs authorization, +use [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). +This function is middleware that wraps an HTTP handler, such as the one returned +by [`NewStreamableHTTPHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewStreamableHTTPHandler), to provide support for verifying bearer tokens. +The middleware function checks every request for an Authorization header with a bearer token, +and invokes the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) + passed to `RequireBearerToken` to parse the token and perform validation. +The middleware function checks expiration and scopes (if they are provided in +[`RequireBearerTokenOptions.Scopes`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.Scopes)), so the +`TokenVerifer` doesn't have to. +If [`RequireBearerTokenOptions.ResourceMetadataURL`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.ResourceMetadataURL) is set and verification fails, +the middleware function sets the WWW-Authenticate header as required by the [Protected Resource +Metadata spec](https://datatracker.ietf.org/doc/html/rfc9728). + +The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. + +### Client + +Client-side OAuth is implemented by setting +[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) +Additional support is forthcoming; see #493. ## Security - +Here we discuss the mitigations described under +the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices) section, and how we handle them. + +### Confused Deputy + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, +happens on the MCP client. At present we don't provide client-side OAuth support. + + +### Token Passthrough + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure +of tokens and is the responsibility of the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) +provided to +[`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). + +### Session Hijacking + +The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows: + +- _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken) +middleware function will verify all HTTP requests that it receives. It is the +user's responsibility to wrap that function around all handlers in their server. + +- _Secure session IDs_. This SDK generates cryptographically secure session IDs by default. +If you create your own with +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID), it is your responsibility to ensure they are secure. +If you are using Go 1.24 or above, +we recommend using [`crypto/rand.Text`](https://pkg.go.dev/crypto/rand#Text) + +- _Binding session IDs to user information_. This is an application requirement, out of scope +for the SDK. You can create your own session IDs by setting +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID). ## Utilities From 28753c9f1227c253f12eabb74a3e20d2a6e91016 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Thu, 18 Sep 2025 14:27:35 -0400 Subject: [PATCH 215/221] docs: document completion support (#495) --- docs/server.md | 46 ++++++++++++++++++++++++++++-- examples/server/completion/main.go | 2 ++ internal/docs/server.src.md | 18 ++++++++++-- 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/docs/server.md b/docs/server.md index 4d6192ff..f59e2c7e 100644 --- a/docs/server.md +++ b/docs/server.md @@ -109,11 +109,51 @@ func Example_prompts() { ## Utilities - - ### Completion - +To support the +[completion](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/completion) +capability, the server needs a completion handler. + +**Client-side**: completion is called using the +[`ClientSession.Complete`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Complete) +method. + +**Server-side**: completion is enabled by setting +[`ServerOptions.CompletionHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.CompletionHandler). +If this field is set to a non-nil value, the server will advertise the +`completions` server capability, and use this handler to respond to completion +requests. + +```go +myCompletionHandler := func(_ context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + // In a real application, you'd implement actual completion logic here. + // For this example, we return a fixed set of suggestions. + var suggestions []string + switch req.Params.Ref.Type { + case "ref/prompt": + suggestions = []string{"suggestion1", "suggestion2", "suggestion3"} + case "ref/resource": + suggestions = []string{"suggestion4", "suggestion5", "suggestion6"} + default: + return nil, fmt.Errorf("unrecognized content type %s", req.Params.Ref.Type) + } + + return &mcp.CompleteResult{ + Completion: mcp.CompletionResultDetails{ + HasMore: false, + Total: len(suggestions), + Values: suggestions, + }, + }, nil +} + +// Create the MCP Server instance and assign the handler. +// No server running, just showing the configuration. +_ = mcp.NewServer(&mcp.Implementation{Name: "server"}, &mcp.ServerOptions{ + CompletionHandler: myCompletionHandler, +}) +``` ### Logging diff --git a/examples/server/completion/main.go b/examples/server/completion/main.go index b0a991fd..5220b0ee 100644 --- a/examples/server/completion/main.go +++ b/examples/server/completion/main.go @@ -16,6 +16,7 @@ import ( // a CompletionHandler to an MCP Server's options. func main() { // Define your custom CompletionHandler logic. + // !+completionhandler myCompletionHandler := func(_ context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { // In a real application, you'd implement actual completion logic here. // For this example, we return a fixed set of suggestions. @@ -43,6 +44,7 @@ func main() { _ = mcp.NewServer(&mcp.Implementation{Name: "server"}, &mcp.ServerOptions{ CompletionHandler: myCompletionHandler, }) + // !-completionhandler log.Println("MCP Server instance created with a CompletionHandler assigned (but not running).") log.Println("This example demonstrates configuration, not live interaction.") diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md index 32f80e9d..50619c60 100644 --- a/internal/docs/server.src.md +++ b/internal/docs/server.src.md @@ -36,11 +36,23 @@ Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/mod ## Utilities - - ### Completion - +To support the +[completion](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/completion) +capability, the server needs a completion handler. + +**Client-side**: completion is called using the +[`ClientSession.Complete`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Complete) +method. + +**Server-side**: completion is enabled by setting +[`ServerOptions.CompletionHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.CompletionHandler). +If this field is set to a non-nil value, the server will advertise the +`completions` server capability, and use this handler to respond to completion +requests. + +%include ../../examples/server/completion/main.go completionhandler - ### Logging From 5985a7cbba537e4039224d85114d547608e99ade Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 19 Sep 2025 10:48:59 -0400 Subject: [PATCH 216/221] docs: document tool support (#498) Document tool support, and tweak the wording for prompts. For #442 --- docs/client.md | 5 +- docs/server.md | 211 ++++++++++++++++++++++++++++++++---- docs/troubleshooting.md | 9 +- internal/docs/server.src.md | 147 +++++++++++++++++++++---- mcp/client.go | 5 +- mcp/protocol.go | 19 +++- mcp/tool_example_test.go | 18 ++- 7 files changed, 358 insertions(+), 56 deletions(-) diff --git a/docs/client.md b/docs/client.md index cbc4db8f..2cbe082c 100644 --- a/docs/client.md +++ b/docs/client.md @@ -56,9 +56,12 @@ func Example_roots() { if _, err := s.Connect(ctx, t1, nil); err != nil { log.Fatal(err) } - if _, err := c.Connect(ctx, t2, nil); err != nil { + + clientSession, err := c.Connect(ctx, t2, nil) + if err != nil { log.Fatal(err) } + defer clientSession.Close() // ...and add a root. The server is notified about the change. c.AddRoots(&mcp.Root{URI: "file://b"}) diff --git a/docs/server.md b/docs/server.md index f59e2c7e..3de5c12b 100644 --- a/docs/server.md +++ b/docs/server.md @@ -11,25 +11,32 @@ ## Prompts -**Server-side**: -MCP servers can provide LLM prompt templates (called simply _prompts_) to clients. -Every prompt has a required name which identifies it, and a set of named arguments, which are strings. -Construct a prompt with a name and descriptions of its arguments. -Associated with each prompt is a handler that expands the template given values for its arguments. -Use [`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) -to add a prompt along with its handler. -If `AddPrompt` is called before a server is connected, the server will have the `prompts` capability. -If all prompts are to be added after connection, set [`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) -to advertise the capability. - -**Client-side**: -To list the server's prompts, call -Call [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) to get an iterator. -If needed, you can use the lower-level -[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) to list the server's prompts. -Call [`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) to retrieve a prompt by name, providing -arguments for expansion. -Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) to be notified of changes in the list of prompts. +MCP servers can provide LLM prompt templates (called simply +[_prompts_](https://modelcontextprotocol.io/specification/2025-06-18/server/prompts)) +to clients. Every prompt has a required name which identifies it, and a set of +named arguments, which are strings. + +**Client-side**: To list the server's prompts, use the +[`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) +iterator, or the lower-level +[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) +(see [pagination](#pagination) below). Set +[`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) +to be notified of changes in the list of prompts. + +Call +[`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) +to retrieve a prompt by name, providing arguments for expansion. + +**Server-side**: Use +[`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) +to add a prompt to the server along with its handler. +The server will have the `prompts` capability if any prompt is added before the +server is connected to a client, or if +[`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/prompts/list_changed` +notification. ```go func Example_prompts() { @@ -73,6 +80,7 @@ func Example_prompts() { if err != nil { log.Fatal(err) } + defer cs.Close() // List the prompts. for p, err := range cs.Prompts(ctx, nil) { @@ -105,7 +113,170 @@ func Example_prompts() { ## Tools - +MCP servers can provide +[tools](https://modelcontextprotocol.io/specification/2025-06-18/server/tools) +to allow clients to interact with external systems or functionality. Tools are +effectively remote function calls, and the Go SDK provides mechanisms to bind +them to ordinary Go functions. + +**Client-side**: To list the server's tools, use the +[`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) +iterator, or the lower-level +[`ClientSession.ListTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListTools) +(see [pagination](#pagination)). Set +[`ClientOptions.ToolListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ToolListChangedHandler) +to be notified of changes in the list of tools. + +To call a tool, use +[`ClientSession.CallTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.CallTool) +with `CallToolParams` holding the name and arguments of the tool to call. + +```go +res, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "my_tool", + Arguments: map[string]any{"name": "user"}, +}) +``` + +Arguments may be any value that can be marshaled to JSON. + +**Server-side**: the basic API for adding a tool is symmetrical with the API +for prompts or resources: +[`Server.AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddTool) +adds a +[`Tool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Tool) to +the server along with its +[`ToolHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ToolHandler) +to handle it. The server will have the `tools` capability if any tool is added +before the server is connected to a client, or if +[`ServerOptions.HasTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a tool is added, any clients already connected to the +server will be notified via a `notifications/tools/list_changed` notification. + +However, the `Server.AddTool` API leaves it to the user to implement the tool +handler correctly according to the spec, providing very little out of the box. +In order to implement a tool, the user must do all of the following: + +- Provide a tool input and output schema. +- Validate the tool arguments against its input schema. +- Unmarshal the input schema into a Go value +- Execute the tool logic. +- Marshal the tool's structured output (if any) to JSON, and store it in the + result's `StructuredOutput` field as well as the unstructured `Content` field. +- Validate that output JSON against the tool's output schema. +- If any tool errors occurred, pack them into the unstructured content and set + `IsError` to `true.` + +For this reason, the SDK provides a generic +[`AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddTool) +function that handles this for you. It can bind a tool to any function with the +following shape: + +```go +func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) +``` + +This is like a `ToolHandler`, but with an extra arbitrary `In` input parameter, +and `Out` output parameter. + +Such a function can then be bound to the server using `AddTool`: + +```go +mcp.AddTool(server, &mcp.Tool{Name: "my_tool"}, handler) +``` + +This does the following automatically: + +- If `Tool.InputSchema` or `Tool.OutputSchema` are unset, the input and output + schemas are inferred from the `In` type, which must be a struct or map. + Optional `jsonschema` struct tags provide argument descriptions. +- Tool arguments are validated against the input schema. +- Tool arguments are marshaled into the `In` value. +- Tool output (the `Out` value) is marshaled into the result's + `StructuredOutput`, as well as the unstructured `Content`. +- Output is validated against the tool's output schema. +- If an ordinary error is returned, it is stored int the `CallToolResult` and + `IsError` is set to `true`. + +In fact, under ordinary circumstances, the user can ignore `CallToolRequest` +and `CallToolResult`. + +For a more realistic example, consider a tool that retrieves the weather: + +```go +type WeatherInput struct { + Location Location `json:"location" jsonschema:"user location"` + Days int `json:"days" jsonschema:"number of days to forecast"` +} + +type WeatherOutput struct { + Summary string `json:"summary" jsonschema:"a summary of the weather forecast"` + Confidence Probability `json:"confidence" jsonschema:"confidence, between 0 and 1"` + AsOf time.Time `json:"asOf" jsonschema:"the time the weather was computed"` + DailyForecast []Forecast `json:"dailyForecast" jsonschema:"the daily forecast"` + Source string `json:"source,omitempty" jsonschema:"the organization providing the weather forecast"` +} + +func WeatherTool(ctx context.Context, req *mcp.CallToolRequest, in WeatherInput) (*mcp.CallToolResult, WeatherOutput, error) { + perfectWeather := WeatherOutput{ + Summary: "perfect", + Confidence: 1.0, + AsOf: time.Now(), + } + for range in.Days { + perfectWeather.DailyForecast = append(perfectWeather.DailyForecast, Forecast{ + Forecast: "another perfect day", + Type: Sunny, + Rain: 0.0, + High: 72.0, + Low: 72.0, + }) + } + return nil, perfectWeather, nil +} +``` + +In this case, we want to customize part of the inferred schema, though we can +still infer the rest. Since we want to control the inference ourselves, we set +the `Tool.InputSchema` explicitly: + +```go +// Distinguished Go types allow custom schemas to be reused during inference. +customSchemas := map[any]*jsonschema.Schema{ + Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, + WeatherType(""): {Type: "string", Enum: []any{Sunny, PartlyCloudy, Cloudy, Rainy, Snowy}}, +} +opts := &jsonschema.ForOptions{TypeSchemas: customSchemas} +in, err := jsonschema.For[WeatherInput](opts) +if err != nil { + log.Fatal(err) +} + +// Furthermore, we can tweak the inferred schema, in this case limiting +// forecasts to 0-10 days. +daysSchema := in.Properties["days"] +daysSchema.Minimum = jsonschema.Ptr(0.0) +daysSchema.Maximum = jsonschema.Ptr(10.0) + +// Output schema inference can reuse our custom schemas from input inference. +out, err := jsonschema.For[WeatherOutput](opts) +if err != nil { + log.Fatal(err) +} + +// Now add our tool to a server. Since we've customized the schemas, we need +// to override the default schema inference. +server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) +mcp.AddTool(server, &mcp.Tool{ + Name: "weather", + InputSchema: in, + OutputSchema: out, +}, WeatherTool) +``` + +_See [mcp/tool_example_test.go](../mcp/tool_example_test.go) for the full +example, or [examples/server/toolschemas](examples/server/toolschemas/main.go) +for more examples of customizing tool schemas._ ## Utilities diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 0f990edc..38410ad5 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -29,16 +29,21 @@ func ExampleLoggingTransport() { ctx := context.Background() t1, t2 := mcp.NewInMemoryTransports() server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) - if _, err := server.Connect(ctx, t1, nil); err != nil { + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { log.Fatal(err) } + defer serverSession.Wait() client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) var b bytes.Buffer logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b} - if _, err := client.Connect(ctx, logTransport, nil); err != nil { + clientSession, err := client.Connect(ctx, logTransport, nil) + if err != nil { log.Fatal(err) } + defer clientSession.Close() + // Sort for stability: reads are concurrent to writes. for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) { fmt.Println(line) diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md index 50619c60..0ef476cd 100644 --- a/internal/docs/server.src.md +++ b/internal/docs/server.src.md @@ -4,25 +4,32 @@ ## Prompts -**Server-side**: -MCP servers can provide LLM prompt templates (called simply _prompts_) to clients. -Every prompt has a required name which identifies it, and a set of named arguments, which are strings. -Construct a prompt with a name and descriptions of its arguments. -Associated with each prompt is a handler that expands the template given values for its arguments. -Use [`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) -to add a prompt along with its handler. -If `AddPrompt` is called before a server is connected, the server will have the `prompts` capability. -If all prompts are to be added after connection, set [`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) -to advertise the capability. - -**Client-side**: -To list the server's prompts, call -Call [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) to get an iterator. -If needed, you can use the lower-level -[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) to list the server's prompts. -Call [`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) to retrieve a prompt by name, providing -arguments for expansion. -Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) to be notified of changes in the list of prompts. +MCP servers can provide LLM prompt templates (called simply +[_prompts_](https://modelcontextprotocol.io/specification/2025-06-18/server/prompts)) +to clients. Every prompt has a required name which identifies it, and a set of +named arguments, which are strings. + +**Client-side**: To list the server's prompts, use the +[`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) +iterator, or the lower-level +[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) +(see [pagination](#pagination) below). Set +[`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) +to be notified of changes in the list of prompts. + +Call +[`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) +to retrieve a prompt by name, providing arguments for expansion. + +**Server-side**: Use +[`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) +to add a prompt to the server along with its handler. +The server will have the `prompts` capability if any prompt is added before the +server is connected to a client, or if +[`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/prompts/list_changed` +notification. %include ../../mcp/server_example_test.go prompts - @@ -32,7 +39,107 @@ Set [`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/mod ## Tools - +MCP servers can provide +[tools](https://modelcontextprotocol.io/specification/2025-06-18/server/tools) +to allow clients to interact with external systems or functionality. Tools are +effectively remote function calls, and the Go SDK provides mechanisms to bind +them to ordinary Go functions. + +**Client-side**: To list the server's tools, use the +[`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) +iterator, or the lower-level +[`ClientSession.ListTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListTools) +(see [pagination](#pagination)). Set +[`ClientOptions.ToolListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ToolListChangedHandler) +to be notified of changes in the list of tools. + +To call a tool, use +[`ClientSession.CallTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.CallTool) +with `CallToolParams` holding the name and arguments of the tool to call. + +```go +res, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "my_tool", + Arguments: map[string]any{"name": "user"}, +}) +``` + +Arguments may be any value that can be marshaled to JSON. + +**Server-side**: the basic API for adding a tool is symmetrical with the API +for prompts or resources: +[`Server.AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddTool) +adds a +[`Tool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Tool) to +the server along with its +[`ToolHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ToolHandler) +to handle it. The server will have the `tools` capability if any tool is added +before the server is connected to a client, or if +[`ServerOptions.HasTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a tool is added, any clients already connected to the +server will be notified via a `notifications/tools/list_changed` notification. + +However, the `Server.AddTool` API leaves it to the user to implement the tool +handler correctly according to the spec, providing very little out of the box. +In order to implement a tool, the user must do all of the following: + +- Provide a tool input and output schema. +- Validate the tool arguments against its input schema. +- Unmarshal the input schema into a Go value +- Execute the tool logic. +- Marshal the tool's structured output (if any) to JSON, and store it in the + result's `StructuredOutput` field as well as the unstructured `Content` field. +- Validate that output JSON against the tool's output schema. +- If any tool errors occurred, pack them into the unstructured content and set + `IsError` to `true.` + +For this reason, the SDK provides a generic +[`AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddTool) +function that handles this for you. It can bind a tool to any function with the +following shape: + +```go +func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) +``` + +This is like a `ToolHandler`, but with an extra arbitrary `In` input parameter, +and `Out` output parameter. + +Such a function can then be bound to the server using `AddTool`: + +```go +mcp.AddTool(server, &mcp.Tool{Name: "my_tool"}, handler) +``` + +This does the following automatically: + +- If `Tool.InputSchema` or `Tool.OutputSchema` are unset, the input and output + schemas are inferred from the `In` type, which must be a struct or map. + Optional `jsonschema` struct tags provide argument descriptions. +- Tool arguments are validated against the input schema. +- Tool arguments are marshaled into the `In` value. +- Tool output (the `Out` value) is marshaled into the result's + `StructuredOutput`, as well as the unstructured `Content`. +- Output is validated against the tool's output schema. +- If an ordinary error is returned, it is stored int the `CallToolResult` and + `IsError` is set to `true`. + +In fact, under ordinary circumstances, the user can ignore `CallToolRequest` +and `CallToolResult`. + +For a more realistic example, consider a tool that retrieves the weather: + +%include ../../mcp/tool_example_test.go weathertool - + +In this case, we want to customize part of the inferred schema, though we can +still infer the rest. Since we want to control the inference ourselves, we set +the `Tool.InputSchema` explicitly: + +%include ../../mcp/tool_example_test.go customschemas - + +_See [mcp/tool_example_test.go](../mcp/tool_example_test.go) for the full +example, or [examples/server/toolschemas](examples/server/toolschemas/main.go) +for more examples of customizing tool schemas._ ## Utilities diff --git a/mcp/client.go b/mcp/client.go index 822566de..dea3e854 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -571,8 +571,9 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) } -// CallTool calls the tool with the given name and arguments. -// The arguments can be any value that marshals into a JSON object. +// CallTool calls the tool with the given parameters. +// +// The params.Arguments can be any value that marshals into a JSON object. func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) { if params == nil { params = new(CallToolParams) diff --git a/mcp/protocol.go b/mcp/protocol.go index 3e3c544e..f3f23f58 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -42,11 +42,14 @@ type Annotations struct { // CallToolParams is used by clients to call a tool. type CallToolParams struct { - // This property is reserved by the protocol to allow clients and servers to + // Meta is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. - Meta `json:"_meta,omitempty"` - Name string `json:"name"` - Arguments any `json:"arguments,omitempty"` + Meta `json:"_meta,omitempty"` + // Name is the name of the tool to call. + Name string `json:"name"` + // Arguments holds the tool arguments. It can hold any value that can be + // marshaled to JSON. + Arguments any `json:"arguments,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -55,8 +58,12 @@ type CallToolParams struct { type CallToolParamsRaw struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. - Meta `json:"_meta,omitempty"` - Name string `json:"name"` + Meta `json:"_meta,omitempty"` + // Name is the name of the tool being called. + Name string `json:"name"` + // Arguments is the raw arguments received over the wire from the client. It + // is the responsibility of the tool handler to unmarshal and validate the + // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` } diff --git a/mcp/tool_example_test.go b/mcp/tool_example_test.go index 888309bc..8f3fbbe6 100644 --- a/mcp/tool_example_test.go +++ b/mcp/tool_example_test.go @@ -83,11 +83,6 @@ func ExampleAddTool_customMarshalling() { // } } -type WeatherInput struct { - Location Location `json:"location" jsonschema:"user location"` - Days int `json:"days" jsonschema:"number of days to forecast"` -} - type Location struct { Name string `json:"name"` Latitude *float64 `json:"latitude,omitempty"` @@ -114,6 +109,13 @@ const ( type Probability float64 +// !+weathertool + +type WeatherInput struct { + Location Location `json:"location" jsonschema:"user location"` + Days int `json:"days" jsonschema:"number of days to forecast"` +} + type WeatherOutput struct { Summary string `json:"summary" jsonschema:"a summary of the weather forecast"` Confidence Probability `json:"confidence" jsonschema:"confidence, between 0 and 1"` @@ -140,11 +142,15 @@ func WeatherTool(ctx context.Context, req *mcp.CallToolRequest, in WeatherInput) return nil, perfectWeather, nil } +// !-weathertool + func ExampleAddTool_complexSchema() { // This example demonstrates a tool with a more 'realistic' input and output // schema. We use a combination of techniques to tune our input and output // schemas. + // !+customschemas + // Distinguished Go types allow custom schemas to be reused during inference. customSchemas := map[any]*jsonschema.Schema{ Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, @@ -177,6 +183,8 @@ func ExampleAddTool_complexSchema() { OutputSchema: out, }, WeatherTool) + // !-customschemas + ctx := context.Background() session, err := connect(ctx, server) // create an in-memory connection if err != nil { From 4fa725857ca5374be0459bb03805c58f10d6eaca Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 19 Sep 2025 11:46:50 -0400 Subject: [PATCH 217/221] docs: document resources (#505) --- docs/server.md | 124 +++++++++++++++++++++++++++++++++++- internal/docs/server.src.md | 54 +++++++++++++++- mcp/server_example_test.go | 73 +++++++++++++++++++++ 3 files changed, 249 insertions(+), 2 deletions(-) diff --git a/docs/server.md b/docs/server.md index 3de5c12b..e16f2fc2 100644 --- a/docs/server.md +++ b/docs/server.md @@ -109,7 +109,129 @@ func Example_prompts() { ## Resources - +In MCP terms, a _resource_ is some data referenced by a URI. +MCP servers can serve resources to clients. +They can register resources individually, or register a _resource template_ +that uses a URI pattern to describe a collection of resources. + + +**Client-side**: +Call [`ClientSession.ReadResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ReadResource) +to read a resource. +The SDK ensures that a read succeeds only if the URI matches a registered resource exactly, +or matches the URI pattern of a resource template. + +To list a server's resources and resource templates, use the +[`ClientSession.Resources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resources) +and +[`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) +iterators, or the lower-level `ListXXX` calls (see [pagination](#pagination)). +Set +[`ClientOptions.ResourceListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceListChangedHandler) +to be notified of changes in the lists of resources or resource templates. + +Clients can be notified when the contents of a resource changes by subscribing to the resource's URI. +Call +[`ClientSession.Subscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Subscribe) +to subscribe to a resource +and +[`ClientSession.Unsubscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Unsubscribe) +to unsubscribe. +Set +[`ClientOptions.ResourceUpdatedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceUpdatedHandler) +to be notified of changes to subscribed resources. + +**Server-side**: +Use +[`Server.AddResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResource) +or +[`Server.AddResourceTemplate`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResourceTemplate) +to add a resource or resource template to the server along with its handler. +A +[`ResourceHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ResourceHandler) +maps a URI to the contents of a resource, which can include text, binary data, +or both. +If `AddResource` or `AddResourceTemplate` is called before a server is connected, the server will have the +`resources` capability. +The server will have the `resources` capability if any resource or resource template is added before the +server is connected to a client, or if +[`ServerOptions.HasResources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasResources) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/resources/list_changed` +notification. + + +```go +func Example_resources() { + ctx := context.Background() + + resources := map[string]string{ + "file:///a": "a", + "file:///dir/x": "x", + "file:///dir/y": "y", + } + + handler := func(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + uri := req.Params.URI + c, ok := resources[uri] + if !ok { + return nil, mcp.ResourceNotFoundError(uri) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{URI: uri, Text: c}}, + }, nil + } + + // Create a server with a single resource. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + s.AddResource(&mcp.Resource{URI: "file:///a"}, handler) + s.AddResourceTemplate(&mcp.ResourceTemplate{URITemplate: "file:///dir/{f}"}, handler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // List resources and resource templates. + for r, err := range cs.Resources(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URI) + } + for r, err := range cs.ResourceTemplates(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URITemplate) + } + + // Read resources. + for _, path := range []string{"a", "dir/x", "b"} { + res, err := cs.ReadResource(ctx, &mcp.ReadResourceParams{URI: "file:///" + path}) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(res.Contents[0].Text) + } + } + // Output: + // file:///a + // file:///dir/{f} + // a + // x + // calling "resources/read": Resource not found +} +``` ## Tools diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md index 0ef476cd..345036b3 100644 --- a/internal/docs/server.src.md +++ b/internal/docs/server.src.md @@ -35,7 +35,59 @@ notification. ## Resources - +In MCP terms, a _resource_ is some data referenced by a URI. +MCP servers can serve resources to clients. +They can register resources individually, or register a _resource template_ +that uses a URI pattern to describe a collection of resources. + + +**Client-side**: +Call [`ClientSession.ReadResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ReadResource) +to read a resource. +The SDK ensures that a read succeeds only if the URI matches a registered resource exactly, +or matches the URI pattern of a resource template. + +To list a server's resources and resource templates, use the +[`ClientSession.Resources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resources) +and +[`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) +iterators, or the lower-level `ListXXX` calls (see [pagination](#pagination)). +Set +[`ClientOptions.ResourceListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceListChangedHandler) +to be notified of changes in the lists of resources or resource templates. + +Clients can be notified when the contents of a resource changes by subscribing to the resource's URI. +Call +[`ClientSession.Subscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Subscribe) +to subscribe to a resource +and +[`ClientSession.Unsubscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Unsubscribe) +to unsubscribe. +Set +[`ClientOptions.ResourceUpdatedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceUpdatedHandler) +to be notified of changes to subscribed resources. + +**Server-side**: +Use +[`Server.AddResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResource) +or +[`Server.AddResourceTemplate`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResourceTemplate) +to add a resource or resource template to the server along with its handler. +A +[`ResourceHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ResourceHandler) +maps a URI to the contents of a resource, which can include text, binary data, +or both. +If `AddResource` or `AddResourceTemplate` is called before a server is connected, the server will have the +`resources` capability. +The server will have the `resources` capability if any resource or resource template is added before the +server is connected to a client, or if +[`ServerOptions.HasResources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasResources) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/resources/list_changed` +notification. + + +%include ../../mcp/server_example_test.go resources - ## Tools diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 16f19e20..d9dd1685 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -82,3 +82,76 @@ func Example_prompts() { } // !-prompts + +// !+resources +func Example_resources() { + ctx := context.Background() + + resources := map[string]string{ + "file:///a": "a", + "file:///dir/x": "x", + "file:///dir/y": "y", + } + + handler := func(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + uri := req.Params.URI + c, ok := resources[uri] + if !ok { + return nil, mcp.ResourceNotFoundError(uri) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{URI: uri, Text: c}}, + }, nil + } + + // Create a server with a single resource. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + s.AddResource(&mcp.Resource{URI: "file:///a"}, handler) + s.AddResourceTemplate(&mcp.ResourceTemplate{URITemplate: "file:///dir/{f}"}, handler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // List resources and resource templates. + for r, err := range cs.Resources(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URI) + } + for r, err := range cs.ResourceTemplates(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URITemplate) + } + + // Read resources. + for _, path := range []string{"a", "dir/x", "b"} { + res, err := cs.ReadResource(ctx, &mcp.ReadResourceParams{URI: "file:///" + path}) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(res.Contents[0].Text) + } + } + // Output: + // file:///a + // file:///dir/{f} + // a + // x + // calling "resources/read": Resource not found +} + +// !-resources From 8b6391b17301a26f59bfa1750ea1b46613cf3f2c Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 19 Sep 2025 11:50:42 -0400 Subject: [PATCH 218/221] docs: explain how to get TokenInfo (#502) --- docs/protocol.md | 7 +++++++ internal/docs/protocol.src.md | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/docs/protocol.md b/docs/protocol.md index e3b9e8c2..c87e522b 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -254,6 +254,13 @@ If [`RequireBearerTokenOptions.ResourceMetadataURL`](https://pkg.go.dev/github.c the middleware function sets the WWW-Authenticate header as required by the [Protected Resource Metadata spec](https://datatracker.ietf.org/doc/html/rfc9728). +Server handlers, such as tool handlers, can obtain the `TokenInfo` returned by the `TokenVerifier` +from `req.Extra.TokenInfo`, where `req` is the handler's request. (For example, a +[`CallToolRequest`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CallToolRequest).) +HTTP handlers wrapped by the `RequireBearerToken` middleware can obtain the `TokenInfo` from the context +with [`auth.TokenInfoFromContext`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenInfoFromContext). + + The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. ### Client diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 0abd5aa8..88c023cc 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -181,6 +181,13 @@ If [`RequireBearerTokenOptions.ResourceMetadataURL`](https://pkg.go.dev/github.c the middleware function sets the WWW-Authenticate header as required by the [Protected Resource Metadata spec](https://datatracker.ietf.org/doc/html/rfc9728). +Server handlers, such as tool handlers, can obtain the `TokenInfo` returned by the `TokenVerifier` +from `req.Extra.TokenInfo`, where `req` is the handler's request. (For example, a +[`CallToolRequest`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CallToolRequest).) +HTTP handlers wrapped by the `RequireBearerToken` middleware can obtain the `TokenInfo` from the context +with [`auth.TokenInfoFromContext`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenInfoFromContext). + + The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. ### Client From 84a568bf895e3e962818638acfe3b7d99d2c1de6 Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 19 Sep 2025 11:54:10 -0400 Subject: [PATCH 219/221] readme: update README for v0.6.0 (#506) Also significantly lessen our warnings, and note that this is a release candidate, rewrite Acknowledgements to be Acknowledgements / Alternatives, since other MCP SDKs continue to be actively developed. --- README.md | 55 +++++++++++++++++------------------ docs/README.md | 4 +-- internal/docs/README.src.md | 4 +-- internal/readme/README.src.md | 55 +++++++++++++++++------------------ 4 files changed, 56 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index ab989028..9263e46b 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,26 @@ -# MCP Go SDK v0.5.0 +# MCP Go SDK v0.6.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) ***BREAKING CHANGES*** -This version contains breaking changes. +This version contains minor breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.5.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.6.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) -This repository contains an unreleased implementation of the official Go -software development kit (SDK) for the Model Context Protocol (MCP). +This repository contains an implementation of the official Go software +development kit (SDK) for the Model Context Protocol (MCP). -> [!WARNING] -> The SDK is not yet at v1.0.0 and may still be subject to incompatible API -> changes. We aim to tag v1.0.0 in September, 2025. See -> https://github.com/modelcontextprotocol/go-sdk/issues/328 for details. +> [!IMPORTANT] +> The SDK is in release-candidate state, and is going to be tagged v1.0.0 +> soon (see https://github.com/modelcontextprotocol/go-sdk/issues/328). +> We do not anticipate significant API changes or instability. Please use it +> and [file issues](https://github.com/modelcontextprotocol/go-sdk/issues/new/choose). -## Package documentation +## Package / Feature documentation The SDK consists of several importable packages: @@ -33,11 +34,13 @@ The SDK consists of several importable packages: - The [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) package provides some primitives for supporting oauth. - - The [`github.com/modelcontextprotocol/go-sdk/oauthex`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/oauthex) package provides extensions to the OAuth protocol, such as ProtectedResourceMetadata. +The SDK endeavors to implement the full MCP spec. The [`docs/`](/docs/) directory +contains feature documentation, mapping the MCP spec to the packages above. + ## Getting started To get started creating an MCP server, create an `mcp.Server` instance, add @@ -78,9 +81,9 @@ func main() { } ``` -To communicate with that server, we can similarly create an `mcp.Client` and -connect it to the corresponding server, by running the server command and -communicating over its stdin/stdout: +To communicate with that server, create an `mcp.Client` and connect it to the +corresponding server, by running the server command and communicating over its +stdin/stdout: ```go package main @@ -128,24 +131,18 @@ func main() { The [`examples/`](/examples/) directory contains more example clients and servers. -## Design - -The design doc for this SDK is at [design.md](./design/design.md), which was -initially reviewed at -[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). +## Contributing -Further design discussion should occur in -[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete -proposals) or -[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. +We welcome contributions to the SDK! Please see See +[CONTRIBUTING.md](/CONTRIBUTING.md) for details of how to contribute. -## Acknowledgements +## Acknowledgements / Alternatives -Several existing Go MCP SDKs inspired the development and design of this -official SDK, notably [mcp-go](https://github.com/mark3labs/mcp-go), authored -by Ed Zynda. We are grateful to Ed as well as the other contributors to mcp-go, -and to authors and contributors of other SDKs such as +Several third party Go MCP SDKs inspired the development and design of this +official SDK, and continue to be viable alternatives, notably +[mcp-go](https://github.com/mark3labs/mcp-go), originally authored by Ed Zynda. +We are grateful to Ed as well as the other contributors to mcp-go, and to +authors and contributors of other SDKs such as [mcp-golang](https://github.com/metoro-io/mcp-golang) and [go-mcp](https://github.com/ThinkInAIXYZ/go-mcp). Thanks to their work, there is a thriving ecosystem of Go MCP clients and servers. diff --git a/docs/README.md b/docs/README.md index 81e7f5f7..b4268c85 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,9 +1,9 @@ -These docs are a work-in-progress. - # Features These docs mirror the [official MCP spec](https://modelcontextprotocol.io/specification/2025-06-18). +Use the index below to learn how the SDK implements a particular aspect of the +protocol. ## Base Protocol diff --git a/internal/docs/README.src.md b/internal/docs/README.src.md index fb600df3..b252f943 100644 --- a/internal/docs/README.src.md +++ b/internal/docs/README.src.md @@ -1,8 +1,8 @@ -These docs are a work-in-progress. - # Features These docs mirror the [official MCP spec](https://modelcontextprotocol.io/specification/2025-06-18). +Use the index below to learn how the SDK implements a particular aspect of the +protocol. ## Base Protocol diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index acd08e41..10ce9b67 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,24 +1,25 @@ -# MCP Go SDK v0.5.0 +# MCP Go SDK v0.6.0 [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) ***BREAKING CHANGES*** -This version contains breaking changes. +This version contains minor breaking changes. See the [release notes]( -https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.5.0) for details. +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.6.0) for details. [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) -This repository contains an unreleased implementation of the official Go -software development kit (SDK) for the Model Context Protocol (MCP). +This repository contains an implementation of the official Go software +development kit (SDK) for the Model Context Protocol (MCP). -> [!WARNING] -> The SDK is not yet at v1.0.0 and may still be subject to incompatible API -> changes. We aim to tag v1.0.0 in September, 2025. See -> https://github.com/modelcontextprotocol/go-sdk/issues/328 for details. +> [!IMPORTANT] +> The SDK is in release-candidate state, and is going to be tagged v1.0.0 +> soon (see https://github.com/modelcontextprotocol/go-sdk/issues/328). +> We do not anticipate significant API changes or instability. Please use it +> and [file issues](https://github.com/modelcontextprotocol/go-sdk/issues/new/choose). -## Package documentation +## Package / Feature documentation The SDK consists of several importable packages: @@ -32,11 +33,13 @@ The SDK consists of several importable packages: - The [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) package provides some primitives for supporting oauth. - - The [`github.com/modelcontextprotocol/go-sdk/oauthex`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/oauthex) package provides extensions to the OAuth protocol, such as ProtectedResourceMetadata. +The SDK endeavors to implement the full MCP spec. The [`docs/`](/docs/) directory +contains feature documentation, mapping the MCP spec to the packages above. + ## Getting started To get started creating an MCP server, create an `mcp.Server` instance, add @@ -46,33 +49,27 @@ stdin/stdout: %include server/server.go - -To communicate with that server, we can similarly create an `mcp.Client` and -connect it to the corresponding server, by running the server command and -communicating over its stdin/stdout: +To communicate with that server, create an `mcp.Client` and connect it to the +corresponding server, by running the server command and communicating over its +stdin/stdout: %include client/client.go - The [`examples/`](/examples/) directory contains more example clients and servers. -## Design - -The design doc for this SDK is at [design.md](./design/design.md), which was -initially reviewed at -[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). +## Contributing -Further design discussion should occur in -[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete -proposals) or -[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See [CONTRIBUTING.md](/CONTRIBUTING.md) for details. +We welcome contributions to the SDK! Please see See +[CONTRIBUTING.md](/CONTRIBUTING.md) for details of how to contribute. -## Acknowledgements +## Acknowledgements / Alternatives -Several existing Go MCP SDKs inspired the development and design of this -official SDK, notably [mcp-go](https://github.com/mark3labs/mcp-go), authored -by Ed Zynda. We are grateful to Ed as well as the other contributors to mcp-go, -and to authors and contributors of other SDKs such as +Several third party Go MCP SDKs inspired the development and design of this +official SDK, and continue to be viable alternatives, notably +[mcp-go](https://github.com/mark3labs/mcp-go), originally authored by Ed Zynda. +We are grateful to Ed as well as the other contributors to mcp-go, and to +authors and contributors of other SDKs such as [mcp-golang](https://github.com/metoro-io/mcp-golang) and [go-mcp](https://github.com/ThinkInAIXYZ/go-mcp). Thanks to their work, there is a thriving ecosystem of Go MCP clients and servers. From 4e9a387d046f77bf44c5c055599b8d313e3b88c3 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 19 Sep 2025 11:58:58 -0400 Subject: [PATCH 220/221] doc: document logging (#497) For #442. --- docs/server.md | 85 ++++++++++++++++++++++++++++++++++++- internal/docs/server.src.md | 28 +++++++++++- mcp/logging.go | 1 + mcp/server_example_test.go | 63 +++++++++++++++++++++++++++ 4 files changed, 175 insertions(+), 2 deletions(-) diff --git a/docs/server.md b/docs/server.md index e16f2fc2..d6656ee9 100644 --- a/docs/server.md +++ b/docs/server.md @@ -450,7 +450,90 @@ _ = mcp.NewServer(&mcp.Implementation{Name: "server"}, &mcp.ServerOptions{ ### Logging - +MCP servers can send logging messages to MCP clients. +(This form of logging is distinct from server-side logging, where the +server produces logs that remain server-side, for use by server maintainers.) + +**Server-side**: +The minimum log level is part of the server state. +For stateful sessions, there is no default log level: no log messages will be sent +until the client calls `SetLevel` (see below). +For stateful sessions, the level defaults to "info". + +[`ServerSession.Log`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Log) is the low-level way for servers to log to clients. +It sends a logging notification to the client if the level of the message +is at least the minimum log level. + +For a simpler API, use [`NewLoggingHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewLoggingHandler) to obtain a [`slog.Handler`](https://pkg.go.dev/log/slog#Handler). +By setting [`LoggingHandlerOptions.MinInterval`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#LoggingHandlerOptions.MinInterval), the handler can be rate-limited +to avoid spamming clients with too many messages. + +Servers always report the logging capability. + + +**Client-side**: +Set [`ClientOptions.LoggingMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.LoggingMessageHandler) to receive log messages. + +Call [`ClientSession.SetLevel`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.SetLevel) to change the log level for a session. + +```go +func Example_logging() { + ctx := context.Background() + + // Create a server. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + + // Create a client that displays log messages. + done := make(chan struct{}) // solely for the example + var nmsgs atomic.Int32 + c := mcp.NewClient( + &mcp.Implementation{Name: "client", Version: "v0.0.1"}, + &mcp.ClientOptions{ + LoggingMessageHandler: func(_ context.Context, r *mcp.LoggingMessageRequest) { + m := r.Params.Data.(map[string]any) + fmt.Println(m["msg"], m["value"]) + if nmsgs.Add(1) == 2 { // number depends on logger calls below + close(done) + } + }, + }) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + ss, err := s.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // Set the minimum log level to "info". + if err := cs.SetLoggingLevel(ctx, &mcp.SetLoggingLevelParams{Level: "info"}); err != nil { + log.Fatal(err) + } + + // Get a slog.Logger for the server session. + logger := slog.New(mcp.NewLoggingHandler(ss, nil)) + + // Log some things. + logger.Info("info shows up", "value", 1) + logger.Debug("debug doesn't show up", "value", 2) + logger.Warn("warn shows up", "value", 3) + + // Wait for them to arrive on the client. + // In a real application, the log messages would appear asynchronously + // while other work was happening. + <-done + + // Output: + // info shows up 1 + // warn shows up 3 +} +``` ### Pagination diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md index 345036b3..c09ba63e 100644 --- a/internal/docs/server.src.md +++ b/internal/docs/server.src.md @@ -215,7 +215,33 @@ requests. ### Logging - +MCP servers can send logging messages to MCP clients. +(This form of logging is distinct from server-side logging, where the +server produces logs that remain server-side, for use by server maintainers.) + +**Server-side**: +The minimum log level is part of the server state. +For stateful sessions, there is no default log level: no log messages will be sent +until the client calls `SetLevel` (see below). +For stateful sessions, the level defaults to "info". + +[`ServerSession.Log`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Log) is the low-level way for servers to log to clients. +It sends a logging notification to the client if the level of the message +is at least the minimum log level. + +For a simpler API, use [`NewLoggingHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewLoggingHandler) to obtain a [`slog.Handler`](https://pkg.go.dev/log/slog#Handler). +By setting [`LoggingHandlerOptions.MinInterval`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#LoggingHandlerOptions.MinInterval), the handler can be rate-limited +to avoid spamming clients with too many messages. + +Servers always report the logging capability. + + +**Client-side**: +Set [`ClientOptions.LoggingMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.LoggingMessageHandler) to receive log messages. + +Call [`ClientSession.SetLevel`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.SetLevel) to change the log level for a session. + +%include ../../mcp/server_example_test.go logging - ### Pagination diff --git a/mcp/logging.go b/mcp/logging.go index 4d33097a..b3186a96 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -70,6 +70,7 @@ type LoggingHandlerOptions struct { // The value for the "logger" field of logging notifications. LoggerName string // Limits the rate at which log messages are sent. + // Excess messages are dropped. // If zero, there is no rate limiting. MinInterval time.Duration } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index d9dd1685..db04920b 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -8,6 +8,8 @@ import ( "context" "fmt" "log" + "log/slog" + "sync/atomic" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -83,6 +85,67 @@ func Example_prompts() { // !-prompts +// !+logging + +func Example_logging() { + ctx := context.Background() + + // Create a server. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + + // Create a client that displays log messages. + done := make(chan struct{}) // solely for the example + var nmsgs atomic.Int32 + c := mcp.NewClient( + &mcp.Implementation{Name: "client", Version: "v0.0.1"}, + &mcp.ClientOptions{ + LoggingMessageHandler: func(_ context.Context, r *mcp.LoggingMessageRequest) { + m := r.Params.Data.(map[string]any) + fmt.Println(m["msg"], m["value"]) + if nmsgs.Add(1) == 2 { // number depends on logger calls below + close(done) + } + }, + }) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + ss, err := s.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // Set the minimum log level to "info". + if err := cs.SetLoggingLevel(ctx, &mcp.SetLoggingLevelParams{Level: "info"}); err != nil { + log.Fatal(err) + } + + // Get a slog.Logger for the server session. + logger := slog.New(mcp.NewLoggingHandler(ss, nil)) + + // Log some things. + logger.Info("info shows up", "value", 1) + logger.Debug("debug doesn't show up", "value", 2) + logger.Warn("warn shows up", "value", 3) + + // Wait for them to arrive on the client. + // In a real application, the log messages would appear asynchronously + // while other work was happening. + <-done + + // Output: + // info shows up 1 + // warn shows up 3 +} + +// !-logging + // !+resources func Example_resources() { ctx := context.Background() From 8c42b69817fcdc11f1249282fc5c43ff7fa8feaa Mon Sep 17 00:00:00 2001 From: Suraj Bobade Date: Fri, 19 Sep 2025 22:31:57 +0530 Subject: [PATCH 221/221] SSE: Add support for SSE handler options (#508) mcp/sse: add support to provide for options for SSE transport - Add SSEOptions struct to define SSE handler options. - Update Signature of NewSSEHandler to accept SSEOptions [braking change] - Update unit test - Update example test Fixes https://github.com/modelcontextprotocol/go-sdk/issues/507, https://github.com/modelcontextprotocol/go-sdk/issues/503 --- examples/server/sse/main.go | 2 +- mcp/sse.go | 18 ++++++++++++++---- mcp/sse_example_test.go | 2 +- mcp/sse_test.go | 2 +- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index 27f9caed..0507dd60 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -65,6 +65,6 @@ func main() { default: return nil } - }) + }, nil) log.Fatal(http.ListenAndServe(addr, handler)) } diff --git a/mcp/sse.go b/mcp/sse.go index f39a0397..7f644918 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -43,12 +43,18 @@ import ( // [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEHandler struct { getServer func(request *http.Request) *Server + opts SSEOptions onConnection func(*ServerSession) // for testing; must not block mu sync.Mutex sessions map[string]*SSEServerTransport } +// SSEOptions specifies options for an [SSEHandler]. +// for now, it is empty, but may be extended in future. +// https://github.com/modelcontextprotocol/go-sdk/issues/507 +type SSEOptions struct{} + // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP // sessions created via incoming HTTP requests. // @@ -62,13 +68,17 @@ type SSEHandler struct { // The getServer function may return a distinct [Server] for each new // request, or reuse an existing server. If it returns nil, the handler // will return a 400 Bad Request. -// -// TODO(rfindley): add options. -func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { - return &SSEHandler{ +func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ getServer: getServer, sessions: make(map[string]*SSEServerTransport), } + + if opts != nil { + s.opts = *opts + } + + return s } // A SSEServerTransport is a logical SSE session created through a hanging GET diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index d06ea62b..6132d31e 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -31,7 +31,7 @@ func ExampleSSEHandler() { server := mcp.NewServer(&mcp.Implementation{Name: "adder", Version: "v0.0.1"}, nil) mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) - handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) + handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }, nil) httpServer := httptest.NewServer(handler) defer httpServer.Close() diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 32a20bf3..25435ff3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -24,7 +24,7 @@ func TestSSEServer(t *testing.T) { server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet"}, sayHi) - sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) serverSessions := make(chan *ServerSession, 1) sseHandler.onConnection = func(ss *ServerSession) {