diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..ea77d0f8 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,26 @@ +{ + "name": "Go MCP SDK Dev Container", + "image": "mcr.microsoft.com/devcontainers/go", + "features": { + "ghcr.io/devcontainers/features/git-lfs:1": {} + }, + "customizations": { + "vscode": { + "extensions": [ + "golang.go", + "ms-vsliveshare.vsliveshare", + "VisualStudioExptTeam.vscodeintellicode", + "eamodio.gitlens", + "usernamehw.errorlens", + "aaron-bond.better-comments", + "GitHub.vscode-github-actions", + "vscode-icons-team.vscode-icons" + ], + "settings": { + "workbench.iconTheme": "vscode-icons" + } + } + }, + + "postCreateCommand": "go mod tidy" +} diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index be13052c..3a0b3d8c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,12 +1,28 @@ -### PR Tips +### PR Guideline Typically, PRs should consist of a single commit, and so should generally follow -the [rules for Go commit messages](https://go.dev/wiki/CommitMessage), with the following -changes and additions: +the [rules for Go commit messages](https://go.dev/wiki/CommitMessage). -- Markdown is allowed. +You **must** follow the form: -- For a pervasive change, use "all" in the title instead of a package name. +``` +net/http: handle foo when bar + +[longer description here in the body] + +Fixes #12345 +``` +Notably, for the subject (the first line of description): +- the name of the package affected by the change goes before the colon +- the part after the colon uses the verb tense + phrase that completes the blank in, “this change modifies this package to ___________” +- the verb after the colon is lowercase +- there is no trailing period +- it should be kept as short as possible + +Additionally: + +- Markdown is allowed. +- For a pervasive change, use "all" in the title instead of a package name. - The PR description should provide context (why this change?) and describe the changes at a high level. Changes that are obvious from the diffs don't need to be mentioned. diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml new file mode 100644 index 00000000..cb6b841b --- /dev/null +++ b/.github/workflows/docs-check.yml @@ -0,0 +1,39 @@ +name: Docs Check +on: + workflow_dispatch: + pull_request: + paths: + - 'internal/readme/**' + - 'README.md' + - 'internal/docs/**' + - 'docs/**' + +permissions: + contents: read + +jobs: + docs-check: + runs-on: ubuntu-latest + steps: + - name: Set up Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + - name: Check out code + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 + - name: Check docs are up-to-date + run: | + go generate ./... + if [ -n "$(git status --porcelain)" ]; then + echo "ERROR: docs are not up-to-date!" + echo "" + echo "The docs differ from what would be generated by `go generate ./...`." + echo "Please update internal/**/*.src.md instead of directly editing README.md or docs/ files," + echo "then run `go generate ./...` to regenerate docs." + echo "" + echo "Changes:" + git status --porcelain + echo "" + echo "Diff:" + git diff + exit 1 + fi + echo "Docs are up-to-date." diff --git a/.github/workflows/readme-check.yml b/.github/workflows/readme-check.yml deleted file mode 100644 index b3398ead..00000000 --- a/.github/workflows/readme-check.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: README Check -on: - workflow_dispatch: - pull_request: - paths: - - 'internal/readme/**' - - 'README.md' - -permissions: - contents: read - -jobs: - readme-check: - runs-on: ubuntu-latest - steps: - - name: Set up Go - uses: actions/setup-go@v5 - - name: Check out code - uses: actions/checkout@v4 - - name: Check README is up-to-date - run: | - cd internal/readme - make - if [ -n "$(git status --porcelain)" ]; then - echo "ERROR: README.md is not up-to-date!" - echo "" - echo "The README.md file differs from what would be generated by running 'make' in internal/readme/." - echo "Please update internal/readme/README.src.md instead of README.md directly," - echo "then run 'make' in the internal/readme/ directory to regenerate README.md." - echo "" - echo "Changes:" - git status --porcelain - echo "" - echo "Diff:" - git diff - exit 1 - fi - echo "README.md is up-to-date" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2f8cb8e..03662463 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,11 +1,11 @@ name: Test on: # Manual trigger - workflow_dispatch: + workflow_dispatch: push: - branches: main + branches: [main] pull_request: - + permissions: contents: read @@ -13,31 +13,51 @@ jobs: lint: runs-on: ubuntu-latest steps: - - name: Set up Go - uses: actions/setup-go@v5 - - name: Check out code - uses: actions/checkout@v4 - - name: Check formatting - run: | - unformatted=$(gofmt -l .) - if [ -n "$unformatted" ]; then - echo "The following files are not properly formatted:" - echo "$unformatted" - exit 1 - fi - echo "All Go files are properly formatted" - + - name: Check out code + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 + - name: Set up Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: "^1.23" + - name: Check formatting + run: | + unformatted=$(gofmt -l .) + if [ -n "$unformatted" ]; then + echo "The following files are not properly formatted:" + echo "$unformatted" + exit 1 + fi + echo "All Go files are properly formatted" + - name: Run Go vet + run: go vet ./... + - name: Run staticcheck + uses: dominikh/staticcheck-action@024238d2898c874f26d723e7d0ff4308c35589a2 # v1 + with: + version: "latest" + test: runs-on: ubuntu-latest strategy: matrix: - go: [ '1.23', '1.24' ] + go: ["1.23", "1.24", "1.25"] + steps: + - name: Check out code + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 + - name: Set up Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: ${{ matrix.go }} + - name: Test + run: go test -v ./... + + race-test: + runs-on: ubuntu-latest steps: - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version: ${{ matrix.go }} - - name: Check out code - uses: actions/checkout@v4 - - name: Test - run: go test -v ./... + - name: Check out code + uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4 + - name: Set up Go + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5 + with: + go-version: "1.24" + - name: Test with -race + run: go test -v -race ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..17341605 --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ +# Builds +*.exe +*.exe~ +dist/ +build/ +bin/ +*.tmp + +# IDE files +.vscode/ +*.code-workspace +.idea/ +*~ + +# Go Specific +*.prof +*.pprof +*.out +*.coverage +coverage.txt +coverage.html + +# OS generated files +# macOS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Windows +Desktop.ini +$RECYCLE.BIN/ + +# Linux +.nfs* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 40739735..726ac95f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -118,6 +118,12 @@ The CI system will automatically check that the README is up-to-date by running `make` and verifying no changes result. If you see a CI failure about the README being out of sync, follow the steps above to regenerate it. +## Timeouts + +If a contributor hasn't responded to issue questions or PR comments in two weeks, +the issue or PR may be closed. It can be reopened when the contributor can resume +work. + ## Code of conduct This project follows the [Go Community Code of Conduct](https://go.dev/conduct). diff --git a/README.md b/README.md index d4900674..9263e46b 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,89 @@ -# MCP Go SDK +# MCP Go SDK v0.6.0 -[![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) -This repository contains an unreleased implementation of the official Go -software development kit (SDK) for the Model Context Protocol (MCP). +***BREAKING CHANGES*** -**WARNING**: The SDK should be considered unreleased, and is currently unstable -and subject to breaking changes. Please test it out and file bug reports or API -proposals, but don't use it in real projects. See the issue tracker for known -issues and missing features. We aim to release a stable version of the SDK in -August, 2025. +This version contains minor breaking changes. +See the [release notes]( +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.6.0) for details. -## Design +[![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) -The design doc for this SDK is at [design.md](./design/design.md), which was -initially reviewed at -[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). +This repository contains an implementation of the official Go software +development kit (SDK) for the Model Context Protocol (MCP). -Further design discussion should occur in -[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete -proposals) or -[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See CONTRIBUTING.md for details. +> [!IMPORTANT] +> The SDK is in release-candidate state, and is going to be tagged v1.0.0 +> soon (see https://github.com/modelcontextprotocol/go-sdk/issues/328). +> We do not anticipate significant API changes or instability. Please use it +> and [file issues](https://github.com/modelcontextprotocol/go-sdk/issues/new/choose). -## Package documentation +## Package / Feature documentation -The SDK consists of two importable packages: +The SDK consists of several importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) package defines the primary APIs for constructing and using MCP clients and servers. - The - [`github.com/modelcontextprotocol/go-sdk/jsonschema`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) - package provides an implementation of [JSON - Schema](https://json-schema.org/), used for MCP tool input and output schema. + [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing + their own transports. +- The + [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) + package provides some primitives for supporting oauth. +- The + [`github.com/modelcontextprotocol/go-sdk/oauthex`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/oauthex) + package provides extensions to the OAuth protocol, such as ProtectedResourceMetadata. -## Example +The SDK endeavors to implement the full MCP spec. The [`docs/`](/docs/) directory +contains feature documentation, mapping the MCP spec to the packages above. -In this example, an MCP client communicates with an MCP server running in a -sidecar process: +## Getting started + +To get started creating an MCP server, create an `mcp.Server` instance, add +features to it, and then run it over an `mcp.Transport`. For example, this +server adds a single simple tool, and then connects it to clients over +stdin/stdout: + +```go +package main + +import ( + "context" + "log" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type Input struct { + Name string `json:"name" jsonschema:"the name of the person to greet"` +} + +type Output struct { + Greeting string `json:"greeting" jsonschema:"the greeting to tell to the user"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, input Input) (*mcp.CallToolResult, Output, error) { + return nil, Output{Greeting: "Hi " + input.Name}, nil +} + +func main() { + // Create a server with a single tool. + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + // Run the server over stdin/stdout, until the client disconnects + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Fatal(err) + } +} +``` + +To communicate with that server, create an `mcp.Client` and connect it to the +corresponding server, by running the server command and communicating over its +stdin/stdout: ```go package main @@ -57,11 +100,11 @@ func main() { ctx := context.Background() // Create a new client, with no features. - client := mcp.NewClient("mcp-client", "v1.0.0", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) // Connect to a server over stdin/stdout - transport := mcp.NewCommandTransport(exec.Command("myserver")) - session, err := client.Connect(ctx, transport) + transport := &mcp.CommandTransport{Command: exec.Command("myserver")} + session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } @@ -85,52 +128,21 @@ func main() { } ``` -Here's an example of the corresponding server component, which communicates -with its client over stdin/stdout: - -```go -package main +The [`examples/`](/examples/) directory contains more example clients and +servers. -import ( - "context" - "log" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type HiParams struct { - Name string `json:"name"` -} - -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + params.Arguments.Name}}, - }, nil -} - -func main() { - // Create a server with a single tool. - server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools( - mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name of the person to greet")), - )), - ) - // Run the server over stdin/stdout, until the client disconnects - if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { - log.Fatal(err) - } -} -``` +## Contributing -The `examples/` directory contains more example clients and servers. +We welcome contributions to the SDK! Please see See +[CONTRIBUTING.md](/CONTRIBUTING.md) for details of how to contribute. -## Acknowledgements +## Acknowledgements / Alternatives -Several existing Go MCP SDKs inspired the development and design of this -official SDK, notably [mcp-go](https://github.com/mark3labs/mcp-go), authored -by Ed Zynda. We are grateful to Ed as well as the other contributors to mcp-go, -and to authors and contributors of other SDKs such as +Several third party Go MCP SDKs inspired the development and design of this +official SDK, and continue to be viable alternatives, notably +[mcp-go](https://github.com/mark3labs/mcp-go), originally authored by Ed Zynda. +We are grateful to Ed as well as the other contributors to mcp-go, and to +authors and contributors of other SDKs such as [mcp-golang](https://github.com/metoro-io/mcp-golang) and [go-mcp](https://github.com/ThinkInAIXYZ/go-mcp). Thanks to their work, there is a thriving ecosystem of Go MCP clients and servers. diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 00000000..7cc0074a --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,122 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "errors" + "net/http" + "slices" + "strings" + "time" +) + +// TokenInfo holds information from a bearer token. +type TokenInfo struct { + Scopes []string + Expiration time.Time + // TODO: add standard JWT fields + Extra map[string]any +} + +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// The error that a TokenVerifier should return for OAuth-specific protocol errors. +var ErrOAuth = errors.New("oauth error") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. +// The HTTP request is provided in case verifying the token involves checking it. +type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) + +// RequireBearerTokenOptions are options for [RequireBearerToken]. +type RequireBearerTokenOptions struct { + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. + ResourceMetadataURL string + // The required scopes. + Scopes []string +} + +type tokenInfoKey struct{} + +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + +// RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. +// If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. +// If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header +// is populated to enable [protected resource metadata]. +// + +// +// [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 +func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { + // Based on typescript-sdk/src/server/auth/middleware/bearerAuth.ts. + + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenInfo, errmsg, code := verify(r, verifier, opts) + if code != 0 { + if code == http.StatusUnauthorized || code == http.StatusForbidden { + if opts != nil && opts.ResourceMetadataURL != "" { + w.Header().Add("WWW-Authenticate", "Bearer resource_metadata="+opts.ResourceMetadataURL) + } + } + http.Error(w, errmsg, code) + return + } + r = r.WithContext(context.WithValue(r.Context(), tokenInfoKey{}, tokenInfo)) + handler.ServeHTTP(w, r) + }) + } +} + +func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) { + // Extract bearer token. + authHeader := req.Header.Get("Authorization") + fields := strings.Fields(authHeader) + if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { + return nil, "no bearer token", http.StatusUnauthorized + } + + // Verify the token and get information from it. + tokenInfo, err := verifier(req.Context(), fields[1], req) + if err != nil { + if errors.Is(err, ErrInvalidToken) { + return nil, err.Error(), http.StatusUnauthorized + } + if errors.Is(err, ErrOAuth) { + return nil, err.Error(), http.StatusBadRequest + } + return nil, err.Error(), http.StatusInternalServerError + } + + // Check scopes. All must be present. + if opts != nil { + // Note: quadratic, but N is small. + for _, s := range opts.Scopes { + if !slices.Contains(tokenInfo.Scopes, s) { + return nil, "insufficient scope", http.StatusForbidden + } + } + } + + // Check expiration. + if tokenInfo.Expiration.IsZero() { + return nil, "token missing expiration", http.StatusUnauthorized + } + if tokenInfo.Expiration.Before(time.Now()) { + return nil, "token expired", http.StatusUnauthorized + } + return tokenInfo, "", 0 +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 00000000..ef8ea7b3 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,78 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + "time" +) + +func TestVerify(t *testing.T) { + verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) { + switch token { + case "valid": + return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + case "invalid": + return nil, ErrInvalidToken + case "oauth": + return nil, ErrOAuth + case "noexp": + return &TokenInfo{}, nil + case "expired": + return &TokenInfo{Expiration: time.Now().Add(-time.Hour)}, nil + default: + return nil, errors.New("unknown") + } + } + + for _, tt := range []struct { + name string + opts *RequireBearerTokenOptions + header string + wantMsg string + wantCode int + }{ + { + "valid", nil, "Bearer valid", + "", 0, + }, + { + "bad header", nil, "Barer valid", + "no bearer token", 401, + }, + { + "invalid", nil, "bearer invalid", + "invalid token", 401, + }, + { + "oauth error", nil, "Bearer oauth", + "oauth error", 400, + }, + { + "no expiration", nil, "Bearer noexp", + "token missing expiration", 401, + }, + { + "expired", nil, "Bearer expired", + "token expired", 401, + }, + { + "missing scope", &RequireBearerTokenOptions{Scopes: []string{"s1"}}, "Bearer valid", + "insufficient scope", 403, + }, + } { + t.Run(tt.name, func(t *testing.T) { + _, gotMsg, gotCode := verify(&http.Request{ + Header: http.Header{"Authorization": {tt.header}}, + }, verifier, tt.opts) + if gotMsg != tt.wantMsg || gotCode != tt.wantCode { + t.Errorf("got (%q, %d), want (%q, %d)", gotMsg, gotCode, tt.wantMsg, tt.wantCode) + } + }) + } +} diff --git a/copyright_test.go b/copyright_test.go new file mode 100644 index 00000000..287e5948 --- /dev/null +++ b/copyright_test.go @@ -0,0 +1,57 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package gosdk + +import ( + "fmt" + "go/parser" + "go/token" + "io/fs" + "path/filepath" + "regexp" + "strings" + "testing" +) + +func TestCopyrightHeaders(t *testing.T) { + var re = regexp.MustCompile(`Copyright \d{4} The Go MCP SDK Authors. All rights reserved. +Use of this source code is governed by an MIT-style +license that can be found in the LICENSE file.`) + + err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories starting with "." or "_", and testdata directories. + if d.IsDir() && d.Name() != "." && + (strings.HasPrefix(d.Name(), ".") || + strings.HasPrefix(d.Name(), "_") || + filepath.Base(d.Name()) == "testdata") { + + return filepath.SkipDir + } + + // Skip non-go files. + if !strings.HasSuffix(path, ".go") { + return nil + } + + // Check the copyright header. + f, err := parser.ParseFile(token.NewFileSet(), path, nil, parser.ParseComments|parser.PackageClauseOnly) + if err != nil { + return fmt.Errorf("parsing %s: %v", path, err) + } + if len(f.Comments) == 0 { + t.Errorf("File %s must start with a copyright header matching %s", path, re) + } else if !re.MatchString(f.Comments[0].Text()) { + t.Errorf("Header comment for %s does not match expected copyright header.\ngot:\n%s\nwant matching:%s", path, f.Comments[0].Text(), re) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/design/design.md b/design/design.md index b52e9c10..fa2270e2 100644 --- a/design/design.md +++ b/design/design.md @@ -100,11 +100,7 @@ The `CommandTransport` is the client side of the stdio transport, and connects b ```go // A CommandTransport is a [Transport] that runs a command and communicates // with it over stdin/stdout, using newline-delimited JSON. -type CommandTransport struct { /* unexported fields */ } - -// NewCommandTransport returns a [CommandTransport] that runs the given command -// and communicates with it over stdin/stdout. -func NewCommandTransport(cmd *exec.Command) *CommandTransport +type CommandTransport struct { Command *exec.Command } // Connect starts the command, and connects to it over stdin/stdout. func (*CommandTransport) Connect(ctx context.Context) (Connection, error) { @@ -115,9 +111,7 @@ The `StdioTransport` is the server side of the stdio transport, and connects by ```go // A StdioTransport is a [Transport] that communicates using newline-delimited // JSON over stdin/stdout. -type StdioTransport struct { /* unexported fields */ } - -func NewStdioTransport() *StdioTransport +type StdioTransport struct { } func (t *StdioTransport) Connect(context.Context) (Connection, error) ``` @@ -128,6 +122,8 @@ The HTTP transport APIs are even more asymmetrical. Since connections are initia Importantly, since they serve many connections, the HTTP handlers must accept a callback to get an MCP server for each new session. As described below, MCP servers can optionally connect to multiple clients. This allows customization of per-session servers: if the MCP server is stateless, the user can return the same MCP server for each connection. On the other hand, if any per-session customization is required, it is possible by returning a different `Server` instance for each connection. +Both the SSE and Streamable HTTP server transports are http.Handlers which serve messages to their associated connection. Consequently, they can be connected at most once. + ```go // SSEHTTPHandler is an http.Handler that serves SSE-based MCP sessions as defined by // the 2024-11-05 version of the MCP protocol. @@ -153,26 +149,10 @@ By default, the SSE handler creates messages endpoints with the `?sessionId=...` ```go // A SSEServerTransport is a logical SSE session created through a hanging GET // request. -// -// When connected, it returns the following [Connection] implementation: -// - Writes are SSE 'message' events to the GET response. -// - Reads are received from POSTs to the session endpoint, via -// [SSEServerTransport.ServeHTTP]. -// - Close terminates the hanging GET. -type SSEServerTransport struct { /* ... */ } - -// NewSSEServerTransport creates a new SSE transport for the given messages -// endpoint, and hanging GET response. -// -// Use [SSEServerTransport.Connect] to initiate the flow of messages. -// -// The transport is itself an [http.Handler]. It is the caller's responsibility -// to ensure that the resulting transport serves HTTP requests on the given -// session endpoint. -// -// Most callers should instead use an [SSEHandler], which transparently handles -// the delegation to SSEServerTransports. -func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport +type SSEServerTransport struct { + Endpoint string + Response http.ResponseWriter +} // ServeHTTP handles POST requests to the transport endpoint. func (*SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) @@ -185,20 +165,14 @@ func (*SSEServerTransport) Connect(context.Context) (Connection, error) The SSE client transport is simpler, and hopefully self-explanatory. ```go -type SSEClientTransport struct { /* ... */ } - -// SSEClientTransportOptions provides options for the [NewSSEClientTransport] -// constructor. -type SSEClientTransportOptions struct { +type SSEClientTransport struct { + // Endpoint is the SSE endpoint to connect to. + Endpoint string // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client } -// NewSSEClientTransport returns a new client transport that connects to the -// SSE server at the provided URL. -func NewSSEClientTransport(url string, opts *SSEClientTransportOptions) (*SSEClientTransport, error) - // Connect connects through the client endpoint. func (*SSEClientTransport) Connect(ctx context.Context) (Connection, error) ``` @@ -218,23 +192,22 @@ func (*StreamableHTTPHandler) Close() error // session ID, not an endpoint, along with the HTTP response for the request // that created the session. It is the caller's responsibility to delegate // requests to this session. -type StreamableServerTransport struct { /* ... */ } -func NewStreamableServerTransport(sessionID string) *StreamableServerTransport +type StreamableServerTransport struct { + // SessionID is the ID of this session. + SessionID string + // Storage for events, to enable stream resumption. + EventStore EventStore +} func (*StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) func (*StreamableServerTransport) Connect(context.Context) (Connection, error) // The streamable client handles reconnection transparently to the user. -type StreamableClientTransport struct { /* ... */ } - -// StreamableClientTransportOptions provides options for the -// [NewStreamableClientTransport] constructor. -type StreamableClientTransportOptions struct { - // HTTPClient is the client to use for making HTTP requests. If nil, - // http.DefaultClient is used. - HTTPClient *http.Client +type StreamableClientTransport struct { + Endpoint string + HTTPClient *http.Client + ReconnectOptions *StreamableReconnectOptions } -func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport func (*StreamableClientTransport) Connect(context.Context) (Connection, error) ``` @@ -257,8 +230,10 @@ func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) // A LoggingTransport is a [Transport] that delegates to another transport, // writing RPC logs to an io.Writer. -type LoggingTransport struct { /* ... */ } -func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport +type LoggingTransport struct { + Delegate Transport + Writer io.Writer +} ``` ### Protocol types @@ -323,7 +298,7 @@ Sessions are created from either `Client` or `Server` using the `Connect` method ```go type Client struct { /* ... */ } -func NewClient(name, version string, opts *ClientOptions) *Client +func NewClient(impl *Implementation, opts *ClientOptions) *Client func (*Client) Connect(context.Context, Transport) (*ClientSession, error) func (*Client) Sessions() iter.Seq[*ClientSession] // Methods for adding/removing client features are described below. @@ -338,7 +313,7 @@ func (*ClientSession) Wait() error // For example: ClientSession.ListTools. type Server struct { /* ... */ } -func NewServer(name, version string, opts *ServerOptions) *Server +func NewServer(impl *Implementation, opts *ServerOptions) *Server func (*Server) Connect(context.Context, Transport) (*ServerSession, error) func (*Server) Sessions() iter.Seq[*ServerSession] // Methods for adding/removing server features are described below. @@ -356,9 +331,11 @@ func (*ServerSession) Wait() error Here's an example of these APIs from the client side: ```go -client := mcp.NewClient("mcp-client", "v1.0.0", nil) +client := mcp.NewClient(&mcp.Implementation{Name:"mcp-client", Version:"v1.0.0"}, nil) // Connect to a server over stdin/stdout -transport := mcp.NewCommandTransport(exec.Command("myserver")) +transport := &mcp.CommandTransport{ + Command: exec.Command("myserver"}, +} session, err := client.Connect(ctx, transport) if err != nil { ... } // Call a tool on the server. @@ -371,13 +348,12 @@ A server that can handle that client call would look like this: ```go // Create a server with a single tool. -server := mcp.NewServer("greeter", "v1.0.0", nil) -server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) +server := mcp.NewServer(&mcp.Implementation{Name:"greeter", Version:"v1.0.0"}, nil) +mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects. -transport := mcp.NewStdioTransport() -session, err := server.Connect(ctx, transport) -... -return session.Wait() +if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Fatal(err) +} ``` For convenience, we provide `Server.Run` to handle the common case of running a session until the client disconnects: @@ -432,15 +408,12 @@ We provide a mechanism to add MCP-level middleware on the both the client and se ```go // A MethodHandler handles MCP messages. -// The params argument is an XXXParams struct pointer, such as *GetPromptParams. -// For methods, a MethodHandler must return either an XXResult struct pointer and a nil error, or -// nil with a non-nil error. -// For notifications, a MethodHandler must return nil, nil. -type MethodHandler[S Session] func( - ctx context.Context, _ *S, method string, params Params) (result Result, err error) +// For methods, exactly one of the return values must be nil. +// For notifications, both must be nil. +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) // Middleware is a function from MethodHandlers to MethodHandlers. -type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] +type Middleware func(MethodHandler) MethodHandler // AddMiddleware wraps the client/server's current method handler using the provided // middleware. Middleware is applied from right to left, so that the first one @@ -448,17 +421,17 @@ type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] // // For example, AddMiddleware(m1, m2, m3) augments the server method handler as // m1(m2(m3(handler))). -func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) -func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]) -func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) -func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]) +func (c *Client) AddSendingMiddleware(middleware ...Middleware) +func (c *Client) AddReceivingMiddleware(middleware ...Middleware) +func (s *Server) AddSendingMiddleware(middleware ...Middleware) +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) ``` As an example, this code adds server-side logging: ```go -func withLogging(h mcp.MethodHandler[*ServerSession]) mcp.MethodHandler[*ServerSession]{ - return func(ctx context.Context, s *mcp.ServerSession, method string, params any) (res any, err error) { +func withLogging(h mcp.MethodHandler) mcp.MethodHandler{ + return func(ctx context.Context, method string, req mcp.Request) (res mcp.Result, err error) { log.Printf("request: %s %v", method, params) defer func() { log.Printf("response: %v, %v", res, err) }() return h(ctx, s , method, params) @@ -472,7 +445,7 @@ server.AddReceivingMiddleware(withLogging) #### Rate Limiting -Rate limiting can be configured using middleware. Please see [examples/rate-limiting](] for an example on how to implement this. +Rate limiting can be configured using middleware. Please see [examples/rate-limiting]() for an example on how to implement this. ### Errors @@ -603,14 +576,14 @@ type ClientOptions struct { ### Tools -A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` is a tool bound to a tool handler. +A `Tool` is a logical MCP tool, generated from the MCP spec. -A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to make `CallToolParams` generic, with a type parameter `TArgs` for the tool argument type. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. The bound `ServerTool` type expects a `json.RawMessage` for its tool arguments, but the `NewServerTool` constructor described below provides a mechanism to bind a typed handler. +A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to have a generic version of `CallToolParams`, with a type parameter `In` for the tool argument type, as well as a generic version of for `CallToolResult`. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. ```go -type CallToolParams[TArgs any] struct { +type CallToolParamsFor[In any] struct { Meta Meta `json:"_meta,omitempty"` - Arguments TArgs `json:"arguments,omitempty"` + Arguments In `json:"arguments,omitempty"` Name string `json:"name"` } @@ -621,23 +594,32 @@ type Tool struct { Name string `json:"name"` } -type ToolHandler[TArgs] func(context.Context, *ServerSession, *CallToolParams[TArgs]) (*CallToolResult, error) +// A ToolHandlerFor handles a call to tools/call with typed arguments and results. +type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) -type ServerTool struct { - Tool Tool - Handler ToolHandler[json.RawMessage] -} ``` -Add tools to a server with `AddTools`: +Add tools to a server with the `AddTool` method or function. The function is generic and infers schemas from the handler +arguments: + +```go +func (s *Server) AddTool(t *Tool, h ToolHandler) +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) +``` + +```go +mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add numbers"}, addHandler) +mcp.AddTool(server, &mcp.Tool{Name: "subtract", Description: "subtract numbers"}, subHandler) +``` +The `AddTool` method requires an input schema, and optionally an output one. It will not modify them. +The handler should accept a `CallToolParams` and return a `CallToolResult`. ```go -server.AddTools( - mcp.NewServerTool("add", "add numbers", addHandler), - mcp.NewServerTool("subtract, subtract numbers", subHandler)) +t := &Tool{Name: ..., Description: ..., InputSchema: &jsonschema.Schema{...}} +server.AddTool(t, myHandler) ``` -Remove them by name with `RemoveTools`: +Tools can be removed by name with `RemoveTools`: ```go server.RemoveTools("add", "subtract") @@ -650,53 +632,30 @@ A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), pr Both of these have their advantages and disadvantages. Reflection is nice, because it allows you to bind directly to a Go API, and means that the JSON schema of your API is compatible with your Go types by construction. It also means that concerns like parsing and validation can be handled automatically. However, it can become cumbersome to express the full breadth of JSON schema using Go types or struct tags, and sometimes you want to express things that aren’t naturally modeled by Go types, like unions. Explicit schemas are simple and readable, and give the caller full control over their tool definition, but involve significant boilerplate. -We have found that a hybrid model works well, where the _initial_ schema is derived using reflection, but any customization on top of that schema is applied using variadic options. We achieve this using a `NewServerTool` helper, which generates the schema from the input type, and wraps the handler to provide parsing and validation. The schema (and potentially other features) can be customized using ToolOptions. +We provide both ways. The `jsonschema.For[T]` function will infer a schema, and it is called by the `AddTool` generic function. +Users can also call it themselves, or build a schema directly as a struct literal. They can still use the `AddTool` function to +create a typed handler, since `AddTool` doesn't touch schemas that are already present. -```go -// NewServerTool creates a Tool using reflection on the given handler. -func NewServerTool[TArgs any](name, description string, handler ToolHandler[TArgs], opts …ToolOption) *ServerTool - -type ToolOption interface { /* ... */ } -``` -`NewServerTool` determines the input schema for a Tool from the `TArgs` type. Each struct field that would be marshaled by `encoding/json.Marshal` becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: +If the tool's `InputSchema` is nil, it is inferred from the `In` type parameter. If the `OutputSchema` is nil, it is inferred from the `Out` type parameter (unless `Out` is `any`). +For example, given this handler: ```go -struct { - Name string `json:"name"` - Count int `json:"count,omitempty"` - Choices []string - Password []byte `json:"-"` +type AddParams struct { + X int `json:"x"` + Y int `json:"y"` } -``` - -"name" and "Choices" are required, while "count" is optional. - -As of this writing, the only `ToolOption` is `Input`, which allows customizing the input schema of the tool using schema options. These schema options are recursive, in the sense that they may also be applied to properties. - -```go -func Input(...SchemaOption) ToolOption - -type Property(name string, opts ...SchemaOption) SchemaOption -type Description(desc string) SchemaOption -// etc. -``` - -For example: -```go -NewServerTool(name, description, handler, - Input(Property("count", Description("size of the inventory")))) +func addHandler(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddParams]]) (*mcp.CallToolResultFor[int], error) { + return &mcp.CallToolResultFor[int]{StructuredContent: req.Params.Arguments.X + req.Params.Arguments.Y}, nil +} ``` -The most recent JSON Schema spec defines over 40 keywords. Providing them all as options would bloat the API despite the fact that most would be very rarely used. For less common keywords, use the `Schema` option to set the schema explicitly: - +You can add it to a server like this: ```go -NewServerTool(name, description, handler, - Input(Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true})))) +mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add numbers"}, addHandler) ``` - -Schemas are validated on the server before the tool handler is called. +The input schema will be inferred from `AddParams`, and the output schema from `int`. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. @@ -704,29 +663,15 @@ Client sessions can call the spec method `ListTools` or an iterator method `Tool ```go func (cs *ClientSession) CallTool(context.Context, *CallToolParams[json.RawMessage]) (*CallToolResult, error) - -func CallTool[TArgs any](context.Context, *ClientSession, *CallToolParams[TArgs]) (*CallToolResult, error) ``` -**Differences from mcp-go**: using variadic options to configure tools was significantly inspired by mcp-go. However, the distinction between `ToolOption` and `SchemaOption` allows for recursive application of schema options. For example, that limitation is visible in [this code](https://github.com/DCjanus/dida365-mcp-server/blob/master/cmd/mcp/tools.go#L315), which must resort to untyped maps to express a nested schema. +**Differences from mcp-go**: We provide a full JSON Schema implementation for validating tool input schemas against incoming arguments. The `jsonschema.Schema` type provides exported features for all keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct any schema they want. The `jsonschema.For[T]` function can infer a schema from a Go struct. These combined features eliminate the need for variadic arguments to construct tool schemas. -Additionally, the `NewServerTool` helper provides a means for building a tool from a Go function using reflection, that automatically handles parsing and validation of inputs. - -We provide a full JSON Schema implementation for validating tool input schemas against incoming arguments. The `jsonschema.Schema` type provides exported features for all keywords in the JSON Schema draft2020-12 spec. Tool definers can use it to construct any schema they want, so there is no need to provide options for all of them. When combined with schema inference from input structs, we found that we needed only three options to cover the common cases, instead of mcp-go's 23. For example, we will provide `Enum`, which occurs 125 times in open source code, but not MinItems, MinLength or MinProperties, which each occur only once (and in an SDK that wraps mcp-go). - -For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, `AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. (Similarly for Delete/Remove). +For registering tools, we provide only a `Server.AddTool` method; mcp-go's `SetTools`, `AddTool`, `AddSessionTool`, and `AddSessionTools` are deemed unnecessary. (Similarly for Delete/Remove). The `AddTool` generic function combines schema inference with registration, providing a easy way to register many tools. ### Prompts -Use `NewServerPrompt` to create a prompt. As with tools, prompt argument schemas can be inferred from a struct, or obtained from options. - -```go -func NewServerPrompt[TReq any](name, description string, - handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), - opts ...PromptOption) *ServerPrompt -``` - -Use `AddPrompts` to add prompts to the server, and `RemovePrompts` +Use `AddPrompt` to add a prompt to the server, and `RemovePrompts` to remove them by name. ```go @@ -734,18 +679,19 @@ type codeReviewArgs struct { Code string `json:"code"` } -func codeReviewHandler(context.Context, *ServerSession, codeReviewArgs) {...} +func codeReviewHandler(context.Context, *ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {...} -server.AddPrompts( - NewServerPrompt("code_review", "review code", codeReviewHandler, - Argument("code", Description("the code to review")))) +server.AddPrompt( + &mcp.Prompt{Name: "code_review", Description: "review code"}, + codeReviewHandler, +) server.RemovePrompts("code_review") ``` Client sessions can call the spec method `ListPrompts` or the iterator method `Prompts` to list the available prompts, and the spec method `GetPrompt` to get one. -**Differences from mcp-go**: We provide a `NewServerPrompt` helper to bind a prompt handler to a Go function using reflection to derive its arguments. We provide `RemovePrompts` to remove prompts from the server. +**Differences from mcp-go**: We provide `RemovePrompts` to remove prompts from the server. ### Resources and resource templates @@ -757,25 +703,11 @@ type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) The arguments include the `ServerSession` so the handler can observe the client's roots. The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the resource. -The `ServerResource` and `ServerResourceTemplate` types hold the association between the resource and its handler: - -```go -type ServerResource struct { - Resource Resource - Handler ResourceHandler -} - -type ServerResourceTemplate struct { - Template ResourceTemplate - Handler ResourceHandler -} -``` - -To add a resource or resource template to a server, users call the `AddResources` and `AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s. We also provide methods to remove them. +To add a resource or resource template to a server, users call the `AddResource` and `AddResourceTemplate` methods. We also provide methods to remove them. ```go -func (*Server) AddResources(...*ServerResource) -func (*Server) AddResourceTemplates(...*ServerResourceTemplate) +func (*Server) AddResource(*Resource, ResourceHandler) +func (*Server) AddResourceTemplate(*ResourceTemplate, ResourceHandler) func (s *Server) RemoveResources(uris ...string) func (s *Server) RemoveResourceTemplates(uriTemplates ...string) @@ -783,37 +715,32 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) The `ReadResource` method finds a resource or resource template matching the argument URI and calls its associated handler. -To read files from the local filesystem, we recommend using `FileResourceHandler` to construct a handler: - -```go -// FileResourceHandler returns a ResourceHandler that reads paths using dir as a root directory. -// It protects against path traversal attacks. -// It will not read any file that is not in the root set of the client requesting the resource. -func (*Server) FileResourceHandler(dir string) ResourceHandler -``` - -Here is an example: - -```go -// Safely read "/public/puppies.txt". -s.AddResources(&mcp.ServerResource{ - Resource: mcp.Resource{URI: "file:///puppies.txt"}, - Handler: s.FileReadResourceHandler("/public")}) -``` - Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, and the corresponding iterator methods `Resources` and `ResourceTemplates`. -**Differences from mcp-go**: for symmetry with tools and prompts, we use `AddResources` rather than `AddResource`. Additionally, the `ResourceHandler` returns a `ReadResourceResult`, rather than just its content, for compatibility with future evolution of the spec. +**Differences from mcp-go**: The `ResourceHandler` returns a `ReadResourceResult`, rather than just its content, for compatibility with future evolution of the spec. #### Subscriptions -ClientSessions can manage change notifications on particular resources: +##### Client-Side Usage + +Use the Subscribe and Unsubscribe methods on a ClientSession to start or stop receiving updates for a specific resource. ```go func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error ``` +To process incoming update notifications, you must provide a ResourceUpdatedHandler in your ClientOptions. The SDK calls this function automatically whenever the server sends a notification for a resource you're subscribed to. + +```go +type ClientOptions struct { + ... + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) +} +``` + +##### Server-Side Implementation + The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to. If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers. @@ -821,17 +748,17 @@ If a server author wants to support resource subscriptions, they must provide ha ```go type ServerOptions struct { ... - // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *SubscribeParams) error - // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *UnsubscribeParams) error + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *ServerRequest[*SubscribeParams]) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *ServerRequest[*UnsubscribeParams]) error } ``` User code should call `ResourceUpdated` when a subscribed resource changes. ```go -func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error +func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotificationParams) error ``` The server routes these notifications to the server sessions that subscribed to the resource. @@ -843,10 +770,10 @@ When a list of tools, prompts or resources changes as the result of an AddXXX or ```go type ClientOptions struct { ... - ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) - PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) + ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams]) + PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams]) // For both resources and resource templates. - ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) + ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams]) } ``` @@ -857,13 +784,10 @@ type ClientOptions struct { Clients call the spec method `Complete` to request completions. If a server installs a `CompletionHandler`, it will be called when the client sends a completion request. ```go -// A CompletionHandler handles a call to completion/complete. -type CompletionHandler func(context.Context, *ServerSession, *CompleteParams) (*CompleteResult, error) - type ServerOptions struct { ... // If non-nil, called when a client sends a completion request. - CompletionHandler CompletionHandler + CompletionHandler func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) } ``` diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..b4268c85 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,40 @@ + +# Features + +These docs mirror the [official MCP spec](https://modelcontextprotocol.io/specification/2025-06-18). +Use the index below to learn how the SDK implements a particular aspect of the +protocol. + +## Base Protocol + +1. [Lifecycle (Clients, Servers, and Sessions)](protocol.md#lifecycle). +1. [Transports](protocol.md#transports) + 1. [Stdio transport](protocol.md#stdio-transport) + 1. [Streamable transport](protocol.md#streamable-transport) + 1. [Custom transports](protocol.md#stateless-mode) +1. [Authorization](protocol.md#authorization) +1. [Security](protocol.md#security) +1. [Utilities](protocol.md#utilities) + 1. [Cancellation](utilities.md#cancellation) + 1. [Ping](utilities.md#ping) + 1. [Progress](utilities.md#progress) + +## Client Features + +1. [Roots](client.md#roots) +1. [Sampling](client.md#sampling) +1. [Elicitation](clients.md#elicitation) + +## Server Features + +1. [Prompts](server.md#prompts) +1. [Resources](server.md#resources) +1. [Tools](tools.md) +1. [Utilities](server.md#utilities) + 1. [Completion](server.md#completion) + 1. [Logging](server.md#logging) + 1. [Pagination](server.md#pagination) + +# TroubleShooting + +See [troubleshooting.md](troubleshooting.md) for a troubleshooting guide. diff --git a/docs/client.md b/docs/client.md new file mode 100644 index 00000000..2cbe082c --- /dev/null +++ b/docs/client.md @@ -0,0 +1,171 @@ + +# Support for MCP client features + +1. [Roots](#roots) +1. [Sampling](#sampling) +1. [Elicitation](#elicitation) + +## Roots + +MCP allows clients to specify a set of filesystem +["roots"](https://modelcontextprotocol.io/specification/2025-06-18/client/roots). +The SDK supports this as follows: + +**Client-side**: The SDK client always has the `roots.listChanged` capability. +To add roots to a client, use the +[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.AddRoots) +and +[`Client.RemoveRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.RemoveRoots) +methods. If any servers are already [connected](protocol.md#lifecycle) to the +client, a call to `AddRoot` or `RemoveRoots` will result in a +`notifications/roots/list_changed` notification to each connected server. + +**Server-side**: To query roots from the server, use the +[`ServerSession.ListRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.ListRoots) +method. To receive notifications about root changes, set +[`ServerOptions.RootsListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.RootsListChangedHandler). + +```go +func Example_roots() { + ctx := context.Background() + + // Create a client with a single root. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + c.AddRoots(&mcp.Root{URI: "file://a"}) + + // Now create a server with a handler to receive notifications about roots. + rootsChanged := make(chan struct{}) + handleRootsChanged := func(ctx context.Context, req *mcp.RootsListChangedRequest) { + rootList, err := req.Session.ListRoots(ctx, nil) + if err != nil { + log.Fatal(err) + } + var roots []string + for _, root := range rootList.Roots { + roots = append(roots, root.URI) + } + fmt.Println(roots) + close(rootsChanged) + } + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + RootsListChangedHandler: handleRootsChanged, + }) + + // Connect the server and client... + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + clientSession, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer clientSession.Close() + + // ...and add a root. The server is notified about the change. + c.AddRoots(&mcp.Root{URI: "file://b"}) + <-rootsChanged + // Output: [file://a file://b] +} +``` + +## Sampling + +[Sampling](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling) +is a way for servers to leverage the client's AI capabilities. It is +implemented in the SDK as follows: + +**Client-side**: To add the `sampling` capability to a client, set +[`ClientOptions.CreateMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.CreateMessageHandler). +This function is invoked whenever the server requests sampling. + +**Server-side**: To use sampling from the server, call +[`ServerSession.CreateMessage`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.CreateMessage). + +```go +func Example_sampling() { + ctx := context.Background() + + // Create a client with a sampling handler. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + return &mcp.CreateMessageResult{ + Content: &mcp.TextContent{ + Text: "would have created a message", + }, + }, nil + }, + }) + + // Connect the server and client... + ct, st := mcp.NewInMemoryTransports() + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + session, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + + msg, err := session.CreateMessage(ctx, &mcp.CreateMessageParams{}) + if err != nil { + log.Fatal(err) + } + fmt.Println(msg.Content.(*mcp.TextContent).Text) + // Output: would have created a message +} +``` + +## Elicitation + +[Elicitation](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation) +allows servers to request user inputs. It is implemented in the SDK as follows: + +**Client-side**: To add the `elicitation` capability to a client, set +[`ClientOptions.ElicitationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ElicitationHandler). +The elicitation handler must return a result that matches the requested schema; +otherwise, elicitation returns an error. + +**Server-side**: To use elicitation from the server, call +[`ServerSession.Elicit`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Elicit). + +```go +func Example_elicitation() { + ctx := context.Background() + ct, st := mcp.NewInMemoryTransports() + + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + c := mcp.NewClient(testImpl, &mcp.ClientOptions{ + ElicitationHandler: func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + res, err := ss.Elicit(ctx, &mcp.ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + if err != nil { + log.Fatal(err) + } + fmt.Println(res.Content["test"]) + // Output: value +} +``` diff --git a/docs/protocol.md b/docs/protocol.md new file mode 100644 index 00000000..c87e522b --- /dev/null +++ b/docs/protocol.md @@ -0,0 +1,456 @@ + +# Support for the MCP base protocol + +1. [Lifecycle](#lifecycle) +1. [Transports](#transports) + 1. [Stdio Transport](#stdio-transport) + 1. [Streamable Transport](#streamable-transport) + 1. [Custom transports](#custom-transports) + 1. [Concurrency](#concurrency) +1. [Authorization](#authorization) + 1. [Server](#server) + 1. [Client](#client) +1. [Security](#security) + 1. [Confused Deputy](#confused-deputy) + 1. [Token Passthrough](#token-passthrough) + 1. [Session Hijacking](#session-hijacking) +1. [Utilities](#utilities) + 1. [Cancellation](#cancellation) + 1. [Ping](#ping) + 1. [Progress](#progress) + +## Lifecycle + +The SDK provides an API for defining both MCP clients and servers, and +connecting them over various transports. When a client and server are +connected, it creates a logical session, which follows the MCP spec's +[lifecycle](https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle). + +In this SDK, both a +[`Client`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client) +and +[`Server`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server) +can handle multiple peers. Every time a new peer is connected, it creates a new +session. + +- A `Client` is a logical MCP client, configured with various + [`ClientOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions). +- When a client is connected to a server using + [`Client.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.Connect), + it creates a + [`ClientSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession). + This session is initialized during the `Connect` method, and provides methods + to communicate with the server peer. +- A `Server` is a logical MCP server, configured with various + [`ServerOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions). +- When a server is connected to a client using + [`Server.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.Connect), + it creates a + [`ServerSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession). + This session is not initialized until the client sends the + `notifications/initialized` message. Use `ServerOptions.InitializedHandler` + to listen for this event, or just use the session through various feature + handlers (such as a `ToolHandler`). Requests to the server are rejected until + the client has initialized the session. + +Both `ClientSession` and `ServerSession` have a `Close` method to terminate the +session, and a `Wait` method to await session termination by the peer. Typically, +it is the client's responsibility to end the session. + +```go +func Example_lifecycle() { + ctx := context.Background() + + // Create a client and server. + // Wait for the client to initialize the session. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + InitializedHandler: func(context.Context, *mcp.InitializedRequest) { + fmt.Println("initialized!") + }, + }) + + // Connect the server and client using in-memory transports. + // + // Connect the server first so that it's ready to receive initialization + // messages from the client. + t1, t2 := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + clientSession, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + + // Now shut down the session by closing the client, and waiting for the + // server session to end. + if err := clientSession.Close(); err != nil { + log.Fatal(err) + } + if err := serverSession.Wait(); err != nil { + log.Fatal(err) + } + // Output: initialized! +} +``` + +## Transports + +A +[transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports) +can be used to send JSON-RPC messages from client to server, or vice-versa. + +In the SDK, this is achieved by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface, which creates a (logical) bidirectional stream of JSON-RPC messages. +Most transport implementations described below are specific to either the +client or server: a "client transport" is something that can be used to connect +a client to a server, and a "server transport" is something that can be used to +connect a server to a client. However, it's possible for a transport to be both +a client and server transport, such as the `InMemoryTransport` used in the +lifecycle example above. + +Transports should not be reused for multiple connections: if you need to create +multiple connections, use different transports. + +### Stdio Transport + +In the +[`stdio`](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio) +transport clients communicate with an MCP server running in a subprocess using +newline-delimited JSON over its stdin/stdout. + +**Client-side**: the client side of the `stdio` transport is implemented by +[`CommandTransport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CommandTransport), +which starts the a `exec.Cmd` as a subprocess and communicates over its +stdin/stdout. + +**Server-side**: the server side of the `stdio` transport is implemented by +`StdioTransport`, which connects over the current processes `os.Stdin` and +`os.Stdout`. + +### Streamable Transport + +The [streamable +transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http) +API is implemented across three types: + +- `StreamableHTTPHandler`: an`http.Handler` that serves streamable MCP + sessions. +- `StreamableServerTransport`: a `Transport` that implements the server side of + the streamable transport. +- `StreamableClientTransport`: a `Transport` that implements the client side of + the streamable transport. + +To create a streamable MCP server, you create a `StreamableHTTPHandler` and +pass it an `mcp.Server`: + +```go +func ExampleStreamableHTTPHandler() { + // Create a new streamable handler, using the same MCP server for every request. + // + // Here, we configure it to serves application/json responses rather than + // text/event-stream, just so the output below doesn't use random event ids. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + resp := mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + fmt.Println(resp) + // Output: + // {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.1.0"}}} +} +``` + +The `StreamableHTTPHandler` handles the HTTP requests and creates a new +`StreamableServerTransport` for each new session. The transport is then used to +communicate with the client. + +On the client side, you create a `StreamableClientTransport` and use it to +connect to the server: + +```go +transport := &mcp.StreamableClientTransport{ + Endpoint: "http://localhost:8080/mcp", +} +client, err := mcp.Connect(ctx, transport, &mcp.ClientOptions{...}) +``` + +The `StreamableClientTransport` handles the HTTP requests and communicates with +the server using the streamable transport protocol. + +#### Stateless Mode + +The streamable server supports a _stateless mode_ by setting +[`StreamableHTTPOptions.Stateless`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableHTTPOptions.Stateless), +which is where the server does not perform any validation of the session id, +and uses a temporary session to handle requests. In this mode, it is impossible +for the server to make client requests, as there is no way for the client's +response to reach the session. + +However, it is still possible for the server to access the `ServerSession.ID` +to see the logical session + +> [!WARNING] +> Stateless mode is not directly discussed in the spec, and is still being +> defined. See modelcontextprotocol/modelcontextprotocol#1364, +> modelcontextprotocol/modelcontextprotocol#1372, or +> modelcontextprotocol/modelcontextprotocol#11442 for potential refinements. + +_See [examples/server/distributed](../examples/server/distributed/main.go) for +an example using statless mode to implement a server distributed across +multiple processes._ + +### Custom transports + +The SDK supports [custom +transports](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#custom-transports) +by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface: a logical bidirectional stream of JSON-RPC messages. + +_Full example: [examples/server/custom-transport](../examples/server/custom-transport/main.go)._ + +### Concurrency + +In general, MCP offers no guarantees about concurrency semantics: if a client +or server sends a notification, the spec says nothing about when the peer +observes that notification relative to other request. However, the Go SDK +implements the following heuristics: + +- If a notifying method (such as `notifications/progress` or + `notifications/initialized`) returns, then it is guaranteed that the peer + observes that notification before other notifications or calls from the same + client goroutine. +- Calls (such as `tools/call`) are handled asynchronously with respect to + each other. + +See +[modelcontextprotocol/go-sdk#26](https://github.com/modelcontextprotocol/go-sdk/issues/26) +for more background. + +## Authorization + +### Server + +To write an MCP server that performs authorization, +use [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). +This function is middleware that wraps an HTTP handler, such as the one returned +by [`NewStreamableHTTPHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewStreamableHTTPHandler), to provide support for verifying bearer tokens. +The middleware function checks every request for an Authorization header with a bearer token, +and invokes the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) + passed to `RequireBearerToken` to parse the token and perform validation. +The middleware function checks expiration and scopes (if they are provided in +[`RequireBearerTokenOptions.Scopes`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.Scopes)), so the +`TokenVerifer` doesn't have to. +If [`RequireBearerTokenOptions.ResourceMetadataURL`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.ResourceMetadataURL) is set and verification fails, +the middleware function sets the WWW-Authenticate header as required by the [Protected Resource +Metadata spec](https://datatracker.ietf.org/doc/html/rfc9728). + +Server handlers, such as tool handlers, can obtain the `TokenInfo` returned by the `TokenVerifier` +from `req.Extra.TokenInfo`, where `req` is the handler's request. (For example, a +[`CallToolRequest`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CallToolRequest).) +HTTP handlers wrapped by the `RequireBearerToken` middleware can obtain the `TokenInfo` from the context +with [`auth.TokenInfoFromContext`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenInfoFromContext). + + +The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. + +### Client + +Client-side OAuth is implemented by setting +[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) +Additional support is forthcoming; see #493. + +## Security + +Here we discuss the mitigations described under +the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices) section, and how we handle them. + +### Confused Deputy + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, +happens on the MCP client. At present we don't provide client-side OAuth support. + + +### Token Passthrough + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure +of tokens and is the responsibility of the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) +provided to +[`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). + +### Session Hijacking + +The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows: + +- _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken) +middleware function will verify all HTTP requests that it receives. It is the +user's responsibility to wrap that function around all handlers in their server. + +- _Secure session IDs_. This SDK generates cryptographically secure session IDs by default. +If you create your own with +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID), it is your responsibility to ensure they are secure. +If you are using Go 1.24 or above, +we recommend using [`crypto/rand.Text`](https://pkg.go.dev/crypto/rand#Text) + +- _Binding session IDs to user information_. This is an application requirement, out of scope +for the SDK. You can create your own session IDs by setting +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID). + +## Utilities + +### Cancellation + +Cancellation is implemented with context cancellation. Cancelling a context +used in a method on `ClientSession` or `ServerSession` will terminate the RPC +and send a "notifications/cancelled" message to the peer. + +When an RPC exits due to a cancellation error, there's a guarantee that the +cancellation notification has been sent, but there's no guarantee that the +server has observed it (see [concurrency](#concurrency)). + +```go +func Example_cancellation() { + // For this example, we're going to be collecting observations from the + // server and client. + var clientResult, serverResult string + var wg sync.WaitGroup + wg.Add(2) + + // Create a server with a single slow tool. + // When the client cancels its request, the server should observe + // cancellation. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + started := make(chan struct{}, 1) // signals that the server started handling the tool call + mcp.AddTool(server, &mcp.Tool{Name: "slow"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + started <- struct{}{} + defer wg.Done() + select { + case <-time.After(5 * time.Second): + serverResult = "tool done" + case <-ctx.Done(): + serverResult = "tool canceled" + } + return &mcp.CallToolResult{}, nil, nil + }) + + // Connect a client to the server. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // Make a tool call, asynchronously. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + _, err = session.CallTool(ctx, &mcp.CallToolParams{Name: "slow"}) + clientResult = fmt.Sprintf("%v", err) + }() + + // As soon as the server has started handling the call, cancel it from the + // client side. + <-started + cancel() + wg.Wait() + + fmt.Println(clientResult) + fmt.Println(serverResult) + // Output: + // context canceled + // tool canceled +} +``` + +### Ping + +[Ping](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/ping) +support is symmetrical for client and server. + +To initiate a ping, call +[`ClientSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Ping) +or +[`ServerSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Ping). + +To have the client or server session automatically ping its peer, and close the +session if the ping fails, set +[`ClientOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.KeepAlive) +or +[`ServerOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.KeepAlive). + +### Progress + +[Progress](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress) +reporting is possible by reading the progress token from request metadata and +calling either +[`ClientSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.NotifyProgress) +or +[`ServerSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.NotifyProgress). +To listen to progress notifications, set +[`ClientOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ProgressNotificationHandler) +or +[`ServerOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.ProgressNotificationHandler). + +Issue #460 discusses some potential ergonomic improvements to this API. + +```go +func Example_progress() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "makeProgress"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if token := req.Params.GetProgressToken(); token != nil { + for i := range 3 { + params := &mcp.ProgressNotificationParams{ + Message: "frobbing widgets", + ProgressToken: token, + Progress: float64(i), + Total: 2, + } + req.Session.NotifyProgress(ctx, params) // ignore error + } + } + return &mcp.CallToolResult{}, nil, nil + }) + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + ProgressNotificationHandler: func(_ context.Context, req *mcp.ProgressNotificationClientRequest) { + fmt.Printf("%s %.0f/%.0f\n", req.Params.Message, req.Params.Progress, req.Params.Total) + }, + }) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "makeProgress", + Meta: mcp.Meta{"progressToken": "abc123"}, + }); err != nil { + log.Fatal(err) + } + // Output: + // frobbing widgets 0/2 + // frobbing widgets 1/2 + // frobbing widgets 2/2 +} +``` diff --git a/docs/server.md b/docs/server.md new file mode 100644 index 00000000..d6656ee9 --- /dev/null +++ b/docs/server.md @@ -0,0 +1,564 @@ + +# Support for MCP server features + +1. [Prompts](#prompts) +1. [Resources](#resources) +1. [Tools](#tools) +1. [Utilities](#utilities) + 1. [Completion](#completion) + 1. [Logging](#logging) + 1. [Pagination](#pagination) + +## Prompts + +MCP servers can provide LLM prompt templates (called simply +[_prompts_](https://modelcontextprotocol.io/specification/2025-06-18/server/prompts)) +to clients. Every prompt has a required name which identifies it, and a set of +named arguments, which are strings. + +**Client-side**: To list the server's prompts, use the +[`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) +iterator, or the lower-level +[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) +(see [pagination](#pagination) below). Set +[`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) +to be notified of changes in the list of prompts. + +Call +[`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) +to retrieve a prompt by name, providing arguments for expansion. + +**Server-side**: Use +[`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) +to add a prompt to the server along with its handler. +The server will have the `prompts` capability if any prompt is added before the +server is connected to a client, or if +[`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/prompts/list_changed` +notification. + +```go +func Example_prompts() { + ctx := context.Background() + + promptHandler := func(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Hi prompt", + Messages: []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}, + }, + }, + }, nil + } + + // Create a server with a single prompt. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + prompt := &mcp.Prompt{ + Name: "greet", + Arguments: []*mcp.PromptArgument{ + { + Name: "name", + Description: "the name of the person to greet", + Required: true, + }, + }, + } + s.AddPrompt(prompt, promptHandler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // List the prompts. + for p, err := range cs.Prompts(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(p.Name) + } + + // Get the prompt. + res, err := cs.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: "greet", + Arguments: map[string]string{"name": "Pat"}, + }) + if err != nil { + log.Fatal(err) + } + for _, msg := range res.Messages { + fmt.Println(msg.Role, msg.Content.(*mcp.TextContent).Text) + } + // Output: + // greet + // user Say hi to Pat +} +``` + +## Resources + +In MCP terms, a _resource_ is some data referenced by a URI. +MCP servers can serve resources to clients. +They can register resources individually, or register a _resource template_ +that uses a URI pattern to describe a collection of resources. + + +**Client-side**: +Call [`ClientSession.ReadResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ReadResource) +to read a resource. +The SDK ensures that a read succeeds only if the URI matches a registered resource exactly, +or matches the URI pattern of a resource template. + +To list a server's resources and resource templates, use the +[`ClientSession.Resources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resources) +and +[`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) +iterators, or the lower-level `ListXXX` calls (see [pagination](#pagination)). +Set +[`ClientOptions.ResourceListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceListChangedHandler) +to be notified of changes in the lists of resources or resource templates. + +Clients can be notified when the contents of a resource changes by subscribing to the resource's URI. +Call +[`ClientSession.Subscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Subscribe) +to subscribe to a resource +and +[`ClientSession.Unsubscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Unsubscribe) +to unsubscribe. +Set +[`ClientOptions.ResourceUpdatedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceUpdatedHandler) +to be notified of changes to subscribed resources. + +**Server-side**: +Use +[`Server.AddResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResource) +or +[`Server.AddResourceTemplate`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResourceTemplate) +to add a resource or resource template to the server along with its handler. +A +[`ResourceHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ResourceHandler) +maps a URI to the contents of a resource, which can include text, binary data, +or both. +If `AddResource` or `AddResourceTemplate` is called before a server is connected, the server will have the +`resources` capability. +The server will have the `resources` capability if any resource or resource template is added before the +server is connected to a client, or if +[`ServerOptions.HasResources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasResources) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/resources/list_changed` +notification. + + +```go +func Example_resources() { + ctx := context.Background() + + resources := map[string]string{ + "file:///a": "a", + "file:///dir/x": "x", + "file:///dir/y": "y", + } + + handler := func(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + uri := req.Params.URI + c, ok := resources[uri] + if !ok { + return nil, mcp.ResourceNotFoundError(uri) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{URI: uri, Text: c}}, + }, nil + } + + // Create a server with a single resource. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + s.AddResource(&mcp.Resource{URI: "file:///a"}, handler) + s.AddResourceTemplate(&mcp.ResourceTemplate{URITemplate: "file:///dir/{f}"}, handler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // List resources and resource templates. + for r, err := range cs.Resources(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URI) + } + for r, err := range cs.ResourceTemplates(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URITemplate) + } + + // Read resources. + for _, path := range []string{"a", "dir/x", "b"} { + res, err := cs.ReadResource(ctx, &mcp.ReadResourceParams{URI: "file:///" + path}) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(res.Contents[0].Text) + } + } + // Output: + // file:///a + // file:///dir/{f} + // a + // x + // calling "resources/read": Resource not found +} +``` + +## Tools + +MCP servers can provide +[tools](https://modelcontextprotocol.io/specification/2025-06-18/server/tools) +to allow clients to interact with external systems or functionality. Tools are +effectively remote function calls, and the Go SDK provides mechanisms to bind +them to ordinary Go functions. + +**Client-side**: To list the server's tools, use the +[`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) +iterator, or the lower-level +[`ClientSession.ListTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListTools) +(see [pagination](#pagination)). Set +[`ClientOptions.ToolListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ToolListChangedHandler) +to be notified of changes in the list of tools. + +To call a tool, use +[`ClientSession.CallTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.CallTool) +with `CallToolParams` holding the name and arguments of the tool to call. + +```go +res, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "my_tool", + Arguments: map[string]any{"name": "user"}, +}) +``` + +Arguments may be any value that can be marshaled to JSON. + +**Server-side**: the basic API for adding a tool is symmetrical with the API +for prompts or resources: +[`Server.AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddTool) +adds a +[`Tool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Tool) to +the server along with its +[`ToolHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ToolHandler) +to handle it. The server will have the `tools` capability if any tool is added +before the server is connected to a client, or if +[`ServerOptions.HasTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a tool is added, any clients already connected to the +server will be notified via a `notifications/tools/list_changed` notification. + +However, the `Server.AddTool` API leaves it to the user to implement the tool +handler correctly according to the spec, providing very little out of the box. +In order to implement a tool, the user must do all of the following: + +- Provide a tool input and output schema. +- Validate the tool arguments against its input schema. +- Unmarshal the input schema into a Go value +- Execute the tool logic. +- Marshal the tool's structured output (if any) to JSON, and store it in the + result's `StructuredOutput` field as well as the unstructured `Content` field. +- Validate that output JSON against the tool's output schema. +- If any tool errors occurred, pack them into the unstructured content and set + `IsError` to `true.` + +For this reason, the SDK provides a generic +[`AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddTool) +function that handles this for you. It can bind a tool to any function with the +following shape: + +```go +func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) +``` + +This is like a `ToolHandler`, but with an extra arbitrary `In` input parameter, +and `Out` output parameter. + +Such a function can then be bound to the server using `AddTool`: + +```go +mcp.AddTool(server, &mcp.Tool{Name: "my_tool"}, handler) +``` + +This does the following automatically: + +- If `Tool.InputSchema` or `Tool.OutputSchema` are unset, the input and output + schemas are inferred from the `In` type, which must be a struct or map. + Optional `jsonschema` struct tags provide argument descriptions. +- Tool arguments are validated against the input schema. +- Tool arguments are marshaled into the `In` value. +- Tool output (the `Out` value) is marshaled into the result's + `StructuredOutput`, as well as the unstructured `Content`. +- Output is validated against the tool's output schema. +- If an ordinary error is returned, it is stored int the `CallToolResult` and + `IsError` is set to `true`. + +In fact, under ordinary circumstances, the user can ignore `CallToolRequest` +and `CallToolResult`. + +For a more realistic example, consider a tool that retrieves the weather: + +```go +type WeatherInput struct { + Location Location `json:"location" jsonschema:"user location"` + Days int `json:"days" jsonschema:"number of days to forecast"` +} + +type WeatherOutput struct { + Summary string `json:"summary" jsonschema:"a summary of the weather forecast"` + Confidence Probability `json:"confidence" jsonschema:"confidence, between 0 and 1"` + AsOf time.Time `json:"asOf" jsonschema:"the time the weather was computed"` + DailyForecast []Forecast `json:"dailyForecast" jsonschema:"the daily forecast"` + Source string `json:"source,omitempty" jsonschema:"the organization providing the weather forecast"` +} + +func WeatherTool(ctx context.Context, req *mcp.CallToolRequest, in WeatherInput) (*mcp.CallToolResult, WeatherOutput, error) { + perfectWeather := WeatherOutput{ + Summary: "perfect", + Confidence: 1.0, + AsOf: time.Now(), + } + for range in.Days { + perfectWeather.DailyForecast = append(perfectWeather.DailyForecast, Forecast{ + Forecast: "another perfect day", + Type: Sunny, + Rain: 0.0, + High: 72.0, + Low: 72.0, + }) + } + return nil, perfectWeather, nil +} +``` + +In this case, we want to customize part of the inferred schema, though we can +still infer the rest. Since we want to control the inference ourselves, we set +the `Tool.InputSchema` explicitly: + +```go +// Distinguished Go types allow custom schemas to be reused during inference. +customSchemas := map[any]*jsonschema.Schema{ + Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, + WeatherType(""): {Type: "string", Enum: []any{Sunny, PartlyCloudy, Cloudy, Rainy, Snowy}}, +} +opts := &jsonschema.ForOptions{TypeSchemas: customSchemas} +in, err := jsonschema.For[WeatherInput](opts) +if err != nil { + log.Fatal(err) +} + +// Furthermore, we can tweak the inferred schema, in this case limiting +// forecasts to 0-10 days. +daysSchema := in.Properties["days"] +daysSchema.Minimum = jsonschema.Ptr(0.0) +daysSchema.Maximum = jsonschema.Ptr(10.0) + +// Output schema inference can reuse our custom schemas from input inference. +out, err := jsonschema.For[WeatherOutput](opts) +if err != nil { + log.Fatal(err) +} + +// Now add our tool to a server. Since we've customized the schemas, we need +// to override the default schema inference. +server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) +mcp.AddTool(server, &mcp.Tool{ + Name: "weather", + InputSchema: in, + OutputSchema: out, +}, WeatherTool) +``` + +_See [mcp/tool_example_test.go](../mcp/tool_example_test.go) for the full +example, or [examples/server/toolschemas](examples/server/toolschemas/main.go) +for more examples of customizing tool schemas._ + +## Utilities + +### Completion + +To support the +[completion](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/completion) +capability, the server needs a completion handler. + +**Client-side**: completion is called using the +[`ClientSession.Complete`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Complete) +method. + +**Server-side**: completion is enabled by setting +[`ServerOptions.CompletionHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.CompletionHandler). +If this field is set to a non-nil value, the server will advertise the +`completions` server capability, and use this handler to respond to completion +requests. + +```go +myCompletionHandler := func(_ context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + // In a real application, you'd implement actual completion logic here. + // For this example, we return a fixed set of suggestions. + var suggestions []string + switch req.Params.Ref.Type { + case "ref/prompt": + suggestions = []string{"suggestion1", "suggestion2", "suggestion3"} + case "ref/resource": + suggestions = []string{"suggestion4", "suggestion5", "suggestion6"} + default: + return nil, fmt.Errorf("unrecognized content type %s", req.Params.Ref.Type) + } + + return &mcp.CompleteResult{ + Completion: mcp.CompletionResultDetails{ + HasMore: false, + Total: len(suggestions), + Values: suggestions, + }, + }, nil +} + +// Create the MCP Server instance and assign the handler. +// No server running, just showing the configuration. +_ = mcp.NewServer(&mcp.Implementation{Name: "server"}, &mcp.ServerOptions{ + CompletionHandler: myCompletionHandler, +}) +``` + +### Logging + +MCP servers can send logging messages to MCP clients. +(This form of logging is distinct from server-side logging, where the +server produces logs that remain server-side, for use by server maintainers.) + +**Server-side**: +The minimum log level is part of the server state. +For stateful sessions, there is no default log level: no log messages will be sent +until the client calls `SetLevel` (see below). +For stateful sessions, the level defaults to "info". + +[`ServerSession.Log`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Log) is the low-level way for servers to log to clients. +It sends a logging notification to the client if the level of the message +is at least the minimum log level. + +For a simpler API, use [`NewLoggingHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewLoggingHandler) to obtain a [`slog.Handler`](https://pkg.go.dev/log/slog#Handler). +By setting [`LoggingHandlerOptions.MinInterval`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#LoggingHandlerOptions.MinInterval), the handler can be rate-limited +to avoid spamming clients with too many messages. + +Servers always report the logging capability. + + +**Client-side**: +Set [`ClientOptions.LoggingMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.LoggingMessageHandler) to receive log messages. + +Call [`ClientSession.SetLevel`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.SetLevel) to change the log level for a session. + +```go +func Example_logging() { + ctx := context.Background() + + // Create a server. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + + // Create a client that displays log messages. + done := make(chan struct{}) // solely for the example + var nmsgs atomic.Int32 + c := mcp.NewClient( + &mcp.Implementation{Name: "client", Version: "v0.0.1"}, + &mcp.ClientOptions{ + LoggingMessageHandler: func(_ context.Context, r *mcp.LoggingMessageRequest) { + m := r.Params.Data.(map[string]any) + fmt.Println(m["msg"], m["value"]) + if nmsgs.Add(1) == 2 { // number depends on logger calls below + close(done) + } + }, + }) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + ss, err := s.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // Set the minimum log level to "info". + if err := cs.SetLoggingLevel(ctx, &mcp.SetLoggingLevelParams{Level: "info"}); err != nil { + log.Fatal(err) + } + + // Get a slog.Logger for the server session. + logger := slog.New(mcp.NewLoggingHandler(ss, nil)) + + // Log some things. + logger.Info("info shows up", "value", 1) + logger.Debug("debug doesn't show up", "value", 2) + logger.Warn("warn shows up", "value", 3) + + // Wait for them to arrive on the client. + // In a real application, the log messages would appear asynchronously + // while other work was happening. + <-done + + // Output: + // info shows up 1 + // warn shows up 3 +} +``` + +### Pagination + +Server-side feature lists may be +[paginated](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/pagination), +using cursors. The SDK supports this by default. + +**Client-side**: The `ClientSession` provides methods returning +[iterators](https://go.dev/blog/range-functions) for each feature type. +These iterators are an `iter.Seq2[Feature, error]`, where the error value +indicates whether page retrieval failed. + +- [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) + iterates prompts. +- [`ClientSession.Resource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resource) + iterates resources. +- [`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) + iterates resource templates. +- [`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) + iterates tools. + +The `ClientSession` also exposes `ListXXX` methods for fine-grained control +over pagination. + +**Server-side**: pagination is on by default, so in general nothing is required +server-side. However, you may use +[`ServerOptions.PageSize`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.PageSize) +to customize the page size. diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 00000000..38410ad5 --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,98 @@ + +# Troubleshooting + +The Model Context Protocol is a complicated spec that leaves some room for +interpretation. Client and server SDKs can behave differently, or can be more +or less strict about their inputs. And of course, bugs happen. + +When you encounter a problem using the Go SDK, these instructions can help +collect information that will be useful in debugging. Please try to provide +this information in a bug report, so that maintainers can more quickly +understand what's going wrong. + +And most of all, please do [file bugs](https://github.com/modelcontextprotocol/go-sdk/issues/new?template=bug_report.md). + +## Using the MCP inspector + +To debug an MCP server, you can use the [MCP +inspector](https://modelcontextprotocol.io/legacy/tools/inspector). This is +useful for testing your server and verifying that it works with the typescript +SDK, as well as inspecting MCP traffic. + +## Collecting MCP logs + +For [stdio](protocol.md#stdio-transport) transport connections, you can also +inspect MCP traffic using a `LoggingTransport`: + +```go +func ExampleLoggingTransport() { + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + defer serverSession.Wait() + + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + var b bytes.Buffer + logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b} + clientSession, err := client.Connect(ctx, logTransport, nil) + if err != nil { + log.Fatal(err) + } + defer clientSession.Close() + + // Sort for stability: reads are concurrent to writes. + for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) { + fmt.Println(line) + } + + // Output: + // read: {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.0.1"}}} + // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} + // write: {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} +} +``` + +That example uses a `bytes.Buffer`, but you can also log to a file, or to +`os.Stderr`. + +## Inspecting HTTP traffic + +There are a couple different ways to investigate traffic to an HTTP transport +([streamable](protocol.md#streamable-transport) or legacy SSE). + +The first is to use an HTTP middleware: + +```go +func ExampleStreamableHTTPHandler_middleware() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, nil) + loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Example debugging; you could also capture the response. + body, err := io.ReadAll(req.Body) + if err != nil { + log.Fatal(err) + } + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewBuffer(body)) + fmt.Println(req.Method, string(body)) + handler.ServeHTTP(w, req) + }) + httpServer := httptest.NewServer(loggingHandler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + // Output: + // POST {"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}} +} +``` + +The second is to use a general purpose tool to inspect http traffic, such as +[wireshark](https://www.wireshark.org/) or +[tcpdump](https://linux.die.net/man/8/tcpdump). diff --git a/examples/client/listfeatures/main.go b/examples/client/listfeatures/main.go new file mode 100644 index 00000000..9d473f0b --- /dev/null +++ b/examples/client/listfeatures/main.go @@ -0,0 +1,65 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The listfeatures command lists all features of a stdio MCP server. +// +// Usage: listfeatures [] +// +// For example: +// +// listfeatures go run github.com/modelcontextprotocol/go-sdk/examples/server/hello +// +// or +// +// listfeatures npx @modelcontextprotocol/server-everything +package main + +import ( + "context" + "flag" + "fmt" + "iter" + "log" + "os" + "os/exec" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func main() { + flag.Parse() + args := flag.Args() + if len(args) == 0 { + fmt.Fprintln(os.Stderr, "Usage: listfeatures []") + fmt.Fprintln(os.Stderr, "List all features for a stdio MCP server") + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "Example:\n\tlistfeatures npx @modelcontextprotocol/server-everything") + os.Exit(2) + } + + ctx := context.Background() + cmd := exec.Command(args[0], args[1:]...) + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + printSection("tools", cs.Tools(ctx, nil), func(t *mcp.Tool) string { return t.Name }) + printSection("resources", cs.Resources(ctx, nil), func(r *mcp.Resource) string { return r.Name }) + printSection("resource templates", cs.ResourceTemplates(ctx, nil), func(r *mcp.ResourceTemplate) string { return r.Name }) + printSection("prompts", cs.Prompts(ctx, nil), func(p *mcp.Prompt) string { return p.Name }) +} + +func printSection[T any](name string, features iter.Seq2[T, error], featName func(T) string) { + fmt.Printf("%s:\n", name) + for feat, err := range features { + if err != nil { + log.Fatal(err) + } + fmt.Printf("\t%s\n", featName(feat)) + } + fmt.Println() +} diff --git a/examples/client/loadtest/main.go b/examples/client/loadtest/main.go new file mode 100644 index 00000000..d5a04c2e --- /dev/null +++ b/examples/client/loadtest/main.go @@ -0,0 +1,122 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The load command load tests a streamable MCP server +// +// Usage: loadtest +// +// For example: +// +// loadtest -tool=greet -args='{"name": "foo"}' http://localhost:8080 +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/signal" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + duration = flag.Duration("duration", 1*time.Minute, "duration of the load test") + tool = flag.String("tool", "", "tool to call") + jsonArgs = flag.String("args", "", "JSON arguments to pass") + workers = flag.Int("workers", 10, "number of concurrent workers") + timeout = flag.Duration("timeout", 1*time.Second, "request timeout") + qps = flag.Int("qps", 100, "tool calls per second, per worker") + verbose = flag.Bool("v", false, "if set, enable verbose logging") +) + +func main() { + flag.Usage = func() { + out := flag.CommandLine.Output() + fmt.Fprintf(out, "Usage: loadtest [flags] ") + fmt.Fprintf(out, "Load test a streamable HTTP server (CTRL-C to end early)") + fmt.Fprintln(out) + fmt.Fprintf(out, "Example: loadtest -tool=greet -args='{\"name\": \"foo\"}' http://localhost:8080\n") + fmt.Fprintln(out) + fmt.Fprintln(out, "Flags:") + flag.PrintDefaults() + } + flag.Parse() + args := flag.Args() + if len(args) != 1 || *tool == "" { + flag.Usage() + os.Exit(2) + } + + parentCtx, cancel := context.WithTimeout(context.Background(), *duration) + defer cancel() + parentCtx, stop := signal.NotifyContext(parentCtx, os.Interrupt) + defer stop() + + var ( + start = time.Now() + success atomic.Int64 + failure atomic.Int64 + ) + + // Run the test. + var wg sync.WaitGroup + for range *workers { + wg.Add(1) + go func() { + defer wg.Done() + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) + cs, err := client.Connect(parentCtx, &mcp.StreamableClientTransport{Endpoint: args[0]}, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + ticker := time.NewTicker(1 * time.Second / time.Duration(*qps)) + defer ticker.Stop() + + for range ticker.C { + ctx, cancel := context.WithTimeout(parentCtx, *timeout) + defer cancel() + + res, err := cs.CallTool(ctx, &mcp.CallToolParams{Name: *tool, Arguments: json.RawMessage(*jsonArgs)}) + if err != nil { + if parentCtx.Err() != nil { + return // test ended + } + failure.Add(1) + if *verbose { + log.Printf("FAILURE: %v", err) + } + } else { + success.Add(1) + if *verbose { + data, err := json.Marshal(res) + if err != nil { + log.Fatalf("marshalling result: %v", err) + } + log.Printf("SUCCESS: %s", string(data)) + } + } + } + }() + } + wg.Wait() + stop() // restore the interrupt signal + + // Print stats. + var ( + dur = time.Since(start) + succ = success.Load() + fail = failure.Load() + ) + fmt.Printf("Results (in %s):\n", dur) + fmt.Printf("\tsuccess: %d (%g QPS)\n", succ, float64(succ)/dur.Seconds()) + fmt.Printf("\tfailure: %d (%g QPS)\n", fail, float64(fail)/dur.Seconds()) +} diff --git a/mcp/example_progress_test.go b/examples/client/middleware/main.go similarity index 50% rename from mcp/example_progress_test.go rename to examples/client/middleware/main.go index 902b2347..9b6d1bb3 100644 --- a/mcp/example_progress_test.go +++ b/examples/client/middleware/main.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -package mcp_test +package main import ( "context" @@ -15,17 +15,16 @@ var nextProgressToken atomic.Int64 // This middleware function adds a progress token to every outgoing request // from the client. -func Example_progressMiddleware() { - c := mcp.NewClient("test", "v1", nil) - c.AddSendingMiddleware(addProgressToken[*mcp.ClientSession]) - _ = c +func main() { + c := mcp.NewClient(&mcp.Implementation{Name: "test"}, nil) + c.AddSendingMiddleware(addProgressToken) } -func addProgressToken[S mcp.Session](h mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, s S, method string, params mcp.Params) (result mcp.Result, err error) { - if rp, ok := params.(mcp.RequestParams); ok { +func addProgressToken(h mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) { + if rp, ok := req.GetParams().(mcp.RequestParams); ok { rp.SetProgressToken(nextProgressToken.Add(1)) } - return h(ctx, s, method, params) + return h(ctx, method, req) } } diff --git a/examples/hello/main.go b/examples/hello/main.go deleted file mode 100644 index 9af34cc3..00000000 --- a/examples/hello/main.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package main - -import ( - "context" - "flag" - "fmt" - "log" - "net/http" - "net/url" - "os" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") - -type HiArgs struct { - Name string `json:"name"` -} - -func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { - return &mcp.CallToolResultFor[struct{}]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, - }, - }, nil -} - -func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { - return &mcp.GetPromptResult{ - Description: "Code review prompt", - Messages: []*mcp.PromptMessage{ - {Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + params.Arguments["name"]}}, - }, - }, nil -} - -func main() { - flag.Parse() - - server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name to say hi to")), - ))) - server.AddPrompts(&mcp.ServerPrompt{ - Prompt: &mcp.Prompt{Name: "greet"}, - Handler: PromptHi, - }) - server.AddResources(&mcp.ServerResource{ - Resource: &mcp.Resource{ - Name: "info", - MIMEType: "text/plain", - URI: "embedded:info", - }, - Handler: handleEmbeddedResource, - }) - - if *httpAddr != "" { - handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { - return server - }, nil) - log.Printf("MCP handler listening at %s", *httpAddr) - http.ListenAndServe(*httpAddr, handler) - } else { - t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) - if err := server.Run(context.Background(), t); err != nil { - log.Printf("Server failed: %v", err) - } - } -} - -var embeddedResources = map[string]string{ - "info": "This is the hello example server.", -} - -func handleEmbeddedResource(_ context.Context, _ *mcp.ServerSession, params *mcp.ReadResourceParams) (*mcp.ReadResourceResult, error) { - u, err := url.Parse(params.URI) - if err != nil { - return nil, err - } - if u.Scheme != "embedded" { - return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) - } - key := u.Opaque - text, ok := embeddedResources[key] - if !ok { - return nil, fmt.Errorf("no embedded resource named %q", key) - } - return &mcp.ReadResourceResult{ - Contents: []*mcp.ResourceContents{ - {URI: params.URI, MIMEType: "text/plain", Text: text}, - }, - }, nil -} diff --git a/examples/http/README.md b/examples/http/README.md new file mode 100644 index 00000000..16a4801d --- /dev/null +++ b/examples/http/README.md @@ -0,0 +1,66 @@ +# MCP HTTP Example + +This example demonstrates how to use the Model Context Protocol (MCP) over HTTP using the streamable transport. It includes both a server and client implementation. + +## Overview + +The example implements: +- A server that provides a `cityTime` tool +- A client that connects to the server, lists available tools, and calls the `cityTime` tool + +## Usage + +Start the Server + +```bash +go run main.go server +``` +This starts an MCP server on `http://localhost:8080` (default) that provides a `cityTime` tool. + +To run a client in another terminal: + +```bash +go run main.go client +``` + +The client will: +1. Connect to the server +2. List available tools +3. Call the `cityTime` tool for NYC, San Francisco, and Boston +4. Display the results + +At any given time you can pass a custom URL to the program to run it on a custom host/port: + +``` +go run main.go -host 0.0.0.0 -port 9000 server +``` + +## Testing with real-world MCP Clients + +Once the server is started, assuming it's the default +localhost:8080, you can try to add it to a popular MCP client: + + claude mcp add -t http timezone http://localhost:8080 + +Once added, Claude Code will be able to discover and use the `cityTime` tool provided by this server. + +In Claude Code: + + > what's the timezone + + ⏺ I'll get the current time in a major US city for you. + + ⏺ timezone - cityTime (MCP)(city: "nyc") + ⎿ The current time in New York City is 7:30:16 PM EDT on Wedn + esday, July 23, 2025 + + + ⏺ The current timezone is EDT (Eastern Daylight Time), and it's + 7:30 PM on Wednesday, July 23, 2025. + + > what timezones do you support? + + ⏺ The timezone tool supports three US cities: + - NYC (Eastern Time) + - SF (Pacific Time) + - Boston (Eastern Time) diff --git a/examples/http/logging_middleware.go b/examples/http/logging_middleware.go new file mode 100644 index 00000000..4266012c --- /dev/null +++ b/examples/http/logging_middleware.go @@ -0,0 +1,51 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "log" + "net/http" + "time" +) + +// responseWriter wraps http.ResponseWriter to capture the status code. +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func loggingHandler(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Create a response writer wrapper to capture status code. + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // Log request details. + log.Printf("[REQUEST] %s | %s | %s %s", + start.Format(time.RFC3339), + r.RemoteAddr, + r.Method, + r.URL.Path) + + // Call the actual handler. + handler.ServeHTTP(wrapped, r) + + // Log response details. + duration := time.Since(start) + log.Printf("[RESPONSE] %s | %s | %s %s | Status: %d | Duration: %v", + time.Now().Format(time.RFC3339), + r.RemoteAddr, + r.Method, + r.URL.Path, + wrapped.statusCode, + duration) + }) +} diff --git a/examples/http/main.go b/examples/http/main.go new file mode 100644 index 00000000..188674ae --- /dev/null +++ b/examples/http/main.go @@ -0,0 +1,200 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "os" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + host = flag.String("host", "localhost", "host to connect to/listen on") + port = flag.Int("port", 8000, "port number to connect to/listen on") + proto = flag.String("proto", "http", "if set, use as proto:// part of URL (ignored for server)") +) + +func main() { + out := flag.CommandLine.Output() + flag.Usage = func() { + fmt.Fprintf(out, "Usage: %s [-proto ] [-port ]\n\n", os.Args[0]) + fmt.Fprintf(out, "This program demonstrates MCP over HTTP using the streamable transport.\n") + fmt.Fprintf(out, "It can run as either a server or client.\n\n") + fmt.Fprintf(out, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(out, "\nExamples:\n") + fmt.Fprintf(out, " Run as server: %s server\n", os.Args[0]) + fmt.Fprintf(out, " Run as client: %s client\n", os.Args[0]) + fmt.Fprintf(out, " Custom host/port: %s -port 9000 -host 0.0.0.0 server\n", os.Args[0]) + os.Exit(1) + } + flag.Parse() + + if flag.NArg() != 1 { + fmt.Fprintf(out, "Error: Must specify 'client' or 'server' as first argument\n") + flag.Usage() + } + mode := flag.Arg(0) + + switch mode { + case "server": + if *proto != "http" { + log.Fatalf("Server only works with 'http' (you passed proto=%s)", *proto) + } + runServer(fmt.Sprintf("%s:%d", *host, *port)) + case "client": + runClient(fmt.Sprintf("%s://%s:%d", *proto, *host, *port)) + default: + fmt.Fprintf(os.Stderr, "Error: Invalid mode '%s'. Must be 'client' or 'server'\n\n", mode) + flag.Usage() + } +} + +// GetTimeParams defines the parameters for the cityTime tool. +type GetTimeParams struct { + City string `json:"city" jsonschema:"City to get time for (nyc, sf, or boston)"` +} + +// getTime implements the tool that returns the current time for a given city. +func getTime(ctx context.Context, req *mcp.CallToolRequest, params *GetTimeParams) (*mcp.CallToolResult, any, error) { + // Define time zones for each city + locations := map[string]string{ + "nyc": "America/New_York", + "sf": "America/Los_Angeles", + "boston": "America/New_York", + } + + city := params.City + if city == "" { + city = "nyc" // Default to NYC + } + + // Get the timezone. + tzName, ok := locations[city] + if !ok { + return nil, nil, fmt.Errorf("unknown city: %s", city) + } + + // Load the location. + loc, err := time.LoadLocation(tzName) + if err != nil { + return nil, nil, fmt.Errorf("failed to load timezone: %w", err) + } + + // Get current time in that location. + now := time.Now().In(loc) + + // Format the response. + cityNames := map[string]string{ + "nyc": "New York City", + "sf": "San Francisco", + "boston": "Boston", + } + + response := fmt.Sprintf("The current time in %s is %s", + cityNames[city], + now.Format(time.RFC3339)) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: response}, + }, + }, nil, nil +} + +func runServer(url string) { + // Create an MCP server. + server := mcp.NewServer(&mcp.Implementation{ + Name: "time-server", + Version: "1.0.0", + }, nil) + + // Add the cityTime tool. + mcp.AddTool(server, &mcp.Tool{ + Name: "cityTime", + Description: "Get the current time in NYC, San Francisco, or Boston", + }, getTime) + + // Create the streamable HTTP handler. + handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server { + return server + }, nil) + + handlerWithLogging := loggingHandler(handler) + + log.Printf("MCP server listening on %s", url) + log.Printf("Available tool: cityTime (cities: nyc, sf, boston)") + + // Start the HTTP server with logging handler. + if err := http.ListenAndServe(url, handlerWithLogging); err != nil { + log.Fatalf("Server failed: %v", err) + } +} + +func runClient(url string) { + ctx := context.Background() + + // Create the URL for the server. + log.Printf("Connecting to MCP server at %s", url) + + // Create an MCP client. + client := mcp.NewClient(&mcp.Implementation{ + Name: "time-client", + Version: "1.0.0", + }, nil) + + // Connect to the server. + session, err := client.Connect(ctx, &mcp.StreamableClientTransport{Endpoint: url}, nil) + if err != nil { + log.Fatalf("Failed to connect: %v", err) + } + defer session.Close() + + log.Printf("Connected to server (session ID: %s)", session.ID()) + + // First, list available tools. + log.Println("Listing available tools...") + toolsResult, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + + for _, tool := range toolsResult.Tools { + log.Printf(" - %s: %s\n", tool.Name, tool.Description) + } + + // Call the cityTime tool for each city. + cities := []string{"nyc", "sf", "boston"} + + log.Println("Getting time for each city...") + for _, city := range cities { + // Call the tool. + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "cityTime", + Arguments: map[string]any{ + "city": city, + }, + }) + if err != nil { + log.Printf("Failed to get time for %s: %v\n", city, err) + continue + } + + // Print the result. + for _, content := range result.Content { + if textContent, ok := content.(*mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + log.Println("Client completed successfully") +} diff --git a/examples/rate-limiting/go.mod b/examples/rate-limiting/go.mod deleted file mode 100644 index 5ec49ddc..00000000 --- a/examples/rate-limiting/go.mod +++ /dev/null @@ -1,8 +0,0 @@ -module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting - -go 1.25 - -require ( - github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 - golang.org/x/time v0.12.0 -) diff --git a/examples/rate-limiting/main.go b/examples/rate-limiting/main.go deleted file mode 100644 index 7e91b79f..00000000 --- a/examples/rate-limiting/main.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package main - -import ( - "context" - "errors" - "time" - - "github.com/modelcontextprotocol/go-sdk/mcp" - "golang.org/x/time/rate" -) - -// GlobalRateLimiterMiddleware creates a middleware that applies a global rate limit. -// Every request attempting to pass through will try to acquire a token. -// If a token cannot be acquired immediately, the request will be rejected. -func GlobalRateLimiterMiddleware[S mcp.Session](limiter *rate.Limiter) mcp.Middleware[S] { - return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { - if !limiter.Allow() { - return nil, errors.New("JSON RPC overloaded") - } - return next(ctx, session, method, params) - } - } -} - -// PerMethodRateLimiterMiddleware creates a middleware that applies rate limiting -// on a per-method basis. -// Methods not specified in limiters will not be rate limited by this middleware. -func PerMethodRateLimiterMiddleware[S mcp.Session](limiters map[string]*rate.Limiter) mcp.Middleware[S] { - return func(next mcp.MethodHandler[S]) mcp.MethodHandler[S] { - return func(ctx context.Context, session S, method string, params mcp.Params) (mcp.Result, error) { - if limiter, ok := limiters[method]; ok { - if !limiter.Allow() { - return nil, errors.New("JSON RPC overloaded") - } - } - return next(ctx, session, method, params) - } - } -} - -func main() { - server := mcp.NewServer("greeter1", "v0.0.1", nil) - server.AddReceivingMiddleware(GlobalRateLimiterMiddleware[*mcp.ServerSession](rate.NewLimiter(rate.Every(time.Second/5), 10))) - server.AddReceivingMiddleware(PerMethodRateLimiterMiddleware[*mcp.ServerSession](map[string]*rate.Limiter{ - "callTool": rate.NewLimiter(rate.Every(time.Second), 5), // once a second with a burst up to 5 - "listTools": rate.NewLimiter(rate.Every(time.Minute), 20), // once a minute with a burst up to 20 - })) - // Run Server logic. -} diff --git a/examples/server/auth-middleware/README.md b/examples/server/auth-middleware/README.md new file mode 100644 index 00000000..913022ff --- /dev/null +++ b/examples/server/auth-middleware/README.md @@ -0,0 +1,247 @@ +# MCP Server with Auth Middleware + +This example demonstrates how to integrate the Go MCP SDK's `auth.RequireBearerToken` middleware with an MCP server to provide authenticated access to MCP tools and resources. + +## Features + +The server provides authentication and authorization capabilities for MCP tools: + +### 1. Authentication Methods + +- **JWT Token Authentication**: JSON Web Token-based authentication +- **API Key Authentication**: API key-based authentication +- **Scope-based Access Control**: Permission-based access to MCP tools + +### 2. MCP Integration + +- **Authenticated MCP Tools**: Tools that require authentication and check permissions +- **Token Generation**: Utility endpoints for generating test tokens +- **Middleware Integration**: Seamless integration with MCP server handlers + +## Setup + +```bash +cd examples/server/auth-middleware +go mod tidy +go run main.go +``` + +## Testing + +```bash +# Run all tests +go test -v + +# Run benchmark tests +go test -bench=. + +# Generate coverage report +go test -cover +``` + +## Endpoints + +### Public Endpoints (No Authentication Required) + +- `GET /health` - Health check + +### MCP Endpoints (Authentication Required) + +- `POST /mcp/jwt` - MCP server with JWT authentication +- `POST /mcp/apikey` - MCP server with API key authentication + +### Utility Endpoints + +- `GET /generate-token` - Generate JWT token +- `POST /generate-api-key` - Generate API key + +## Available MCP Tools + +The server provides three authenticated MCP tools: + +### 1. Say Hi (`say_hi`) + +A simple greeting tool that requires authentication. + +**Parameters:** +- None required + +**Required Scopes:** +- Any authenticated user + +### 2. Get User Info (`get_user_info`) + +Retrieves user information based on the provided user ID. + +**Parameters:** +- `user_id` (string): The user ID to get information for + +**Required Scopes:** +- `read` permission + +### 3. Create Resource (`create_resource`) + +Creates a new resource with the provided details. + +**Parameters:** +- `name` (string): The name of the resource +- `description` (string): The description of the resource +- `content` (string): The content of the resource + +**Required Scopes:** +- `write` permission + +## Example Usage + +### 1. Generating JWT Token and Using MCP Tools + +```bash +# Generate a token +curl 'http://localhost:8080/generate-token?user_id=alice&scopes=read,write' + +# Use MCP tool with JWT authentication +curl -H 'Authorization: Bearer ' \ + -H 'Content-Type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"say_hi","arguments":{}}}' \ + http://localhost:8080/mcp/jwt +``` + +### 2. Generating API Key and Using MCP Tools + +```bash +# Generate an API key +curl -X POST 'http://localhost:8080/generate-api-key?user_id=bob&scopes=read' + +# Use MCP tool with API key authentication +curl -H 'Authorization: Bearer ' \ + -H 'Content-Type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_user_info","arguments":{"user_id":"test"}}}' \ + http://localhost:8080/mcp/apikey +``` + +### 3. Testing Scope Restrictions + +```bash +# Access MCP tool requiring write scope +curl -H 'Authorization: Bearer ' \ + -H 'Content-Type: application/json' \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_resource","arguments":{"name":"test","description":"test resource","content":"test content"}}}' \ + http://localhost:8080/mcp/jwt +``` + +## Core Concepts + +### Authentication Integration + +This example demonstrates how to integrate `auth.RequireBearerToken` middleware with an MCP server to provide authenticated access. The MCP server operates as an HTTP handler protected by authentication middleware. + +### Key Features + +1. **MCP Server Integration**: Create MCP server using `mcp.NewServer` +2. **Authentication Middleware**: Protect MCP handlers with `auth.RequireBearerToken` +3. **Token Verification**: Validate tokens using provided `TokenVerifier` functions +4. **Scope Checking**: Verify required permissions (scopes) are present +5. **Expiration Validation**: Check that tokens haven't expired +6. **Context Injection**: Add verified token information to request context +7. **Authenticated MCP Tools**: Tools that operate based on authentication information +8. **Error Handling**: Return appropriate HTTP status codes and error messages on authentication failure + +### Implementation + +```go +// Create MCP server +server := mcp.NewServer(&mcp.Implementation{Name: "authenticated-mcp-server"}, nil) + +// Create authentication middleware +authMiddleware := auth.RequireBearerToken(verifier, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read", "write"}, +}) + +// Create MCP handler +handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server +}, nil) + +// Apply authentication middleware to MCP handler +authenticatedHandler := authMiddleware(customMiddleware(handler)) +``` + +### Parameters + +- **verifier**: Function to verify tokens (`TokenVerifier` type) +- **opts**: Authentication options + - `Scopes`: List of required permissions + - `ResourceMetadataURL`: OAuth 2.0 resource metadata URL + +### Error Responses + +- **401 Unauthorized**: Token is invalid, expired, or missing +- **403 Forbidden**: Required scopes are insufficient +- **WWW-Authenticate Header**: Included when resource metadata URL is configured + +## Implementation Details + +### 1. TokenVerifier Implementation + +```go +func jwtVerifier(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { + // JWT token verification logic + // On success: Return TokenInfo + // On failure: Return auth.ErrInvalidToken +} +``` + +### 2. Using Authentication Information in MCP Tools + +```go +// Get authentication information in MCP tool +func MyTool(ctx context.Context, req *mcp.CallToolRequest, args MyArgs) (*mcp.CallToolResult, any, error) { + // Extract authentication info from request + userInfo := req.Extra.TokenInfo + + // Check scopes + if !slices.Contains(userInfo.Scopes, "read") { + return nil, nil, fmt.Errorf("insufficient permissions: read scope required") + } + + // Execute tool logic + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Tool executed successfully"}, + }, + }, nil, nil +} +``` + +### 3. Middleware Composition + +```go +// Combine authentication middleware with custom middleware +authenticatedHandler := authMiddleware(customMiddleware(mcpHandler)) +``` + +## Security Best Practices + +1. **Environment Variables**: Use environment variables for JWT secrets in production +2. **Database Storage**: Store API keys in a database +3. **HTTPS Usage**: Always use HTTPS in production environments +4. **Token Expiration**: Set appropriate token expiration times +5. **Principle of Least Privilege**: Grant only the minimum required scopes + +## Use Cases + +**Ideal for:** + +- MCP servers requiring authentication and authorization +- Applications needing scope-based access control +- Systems requiring both JWT and API key authentication +- Projects needing secure MCP tool access +- Scenarios requiring audit trails and permission management + +**Examples:** + +- Enterprise MCP servers with user management +- Multi-tenant MCP applications +- Secure API gateways with MCP integration +- Development environments with authentication requirements +- Production systems requiring fine-grained access control diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod new file mode 100644 index 00000000..f1ca77fa --- /dev/null +++ b/examples/server/auth-middleware/go.mod @@ -0,0 +1,15 @@ +module auth-middleware-example + +go 1.23.0 + +require ( + github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/modelcontextprotocol/go-sdk v0.3.0 +) + +require ( + github.com/google/jsonschema-go v0.2.3 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect +) + +replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum new file mode 100644 index 00000000..6a392638 --- /dev/null +++ b/examples/server/auth-middleware/go.sum @@ -0,0 +1,10 @@ +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/server/auth-middleware/main.go b/examples/server/auth-middleware/main.go new file mode 100644 index 00000000..75a38b86 --- /dev/null +++ b/examples/server/auth-middleware/main.go @@ -0,0 +1,378 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "slices" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// This example demonstrates how to integrate auth.RequireBearerToken middleware +// with an MCP server to provide authenticated access to MCP tools and resources. + +var httpAddr = flag.String("http", ":8080", "HTTP address to listen on") + +// JWTClaims represents the claims in our JWT tokens. +// In a real application, you would include additional claims like issuer, audience, etc. +type JWTClaims struct { + UserID string `json:"user_id"` // User identifier + Scopes []string `json:"scopes"` // Permissions/roles for the user + jwt.RegisteredClaims +} + +// APIKey represents an API key with associated scopes. +// In production, this would be stored in a database with additional metadata. +type APIKey struct { + Key string `json:"key"` // The actual API key value + UserID string `json:"user_id"` // User identifier + Scopes []string `json:"scopes"` // Permissions/roles for this key +} + +// In-memory storage for API keys (in production, use a database). +// This is for demonstration purposes only. +var apiKeys = map[string]*APIKey{ + "sk-1234567890abcdef": { + Key: "sk-1234567890abcdef", + UserID: "user1", + Scopes: []string{"read", "write"}, + }, + "sk-abcdef1234567890": { + Key: "sk-abcdef1234567890", + UserID: "user2", + Scopes: []string{"read"}, + }, +} + +// JWT secret (in production, use environment variables). +// This should be a strong, randomly generated secret in real applications. +var jwtSecret = []byte("your-secret-key") + +// generateToken creates a JWT token for testing purposes. +// In a real application, this would be handled by your authentication service. +func generateToken(userID string, scopes []string, expiresIn time.Duration) (string, error) { + // Create JWT claims with user information and scopes. + claims := JWTClaims{ + UserID: userID, + Scopes: scopes, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiresIn)), // Token expiration + IssuedAt: jwt.NewNumericDate(time.Now()), // Token issuance time + NotBefore: jwt.NewNumericDate(time.Now()), // Token validity start time + }, + } + + // Create and sign the JWT token. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(jwtSecret) +} + +// verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware. +// This function implements the TokenVerifier interface required by auth.RequireBearerToken. +func verifyJWT(ctx context.Context, tokenString string, _ *http.Request) (*auth.TokenInfo, error) { + // Parse and validate the JWT token. + token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) { + // Verify the signing method is HMAC. + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return jwtSecret, nil + }) + if err != nil { + // Return standard error for invalid tokens. + return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) + } + + // Extract claims and verify token validity. + if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid { + return &auth.TokenInfo{ + Scopes: claims.Scopes, // User permissions + Expiration: claims.ExpiresAt.Time, // Token expiration time + }, nil + } + + return nil, fmt.Errorf("%w: invalid token claims", auth.ErrInvalidToken) +} + +// verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware. +// This function implements the TokenVerifier interface required by auth.RequireBearerToken. +func verifyAPIKey(ctx context.Context, apiKey string, _ *http.Request) (*auth.TokenInfo, error) { + // Look up the API key in our storage. + key, exists := apiKeys[apiKey] + if !exists { + return nil, fmt.Errorf("%w: API key not found", auth.ErrInvalidToken) + } + + // API keys don't expire in this example, but you could add expiration logic here. + // For demonstration, we set a 24-hour expiration. + return &auth.TokenInfo{ + Scopes: key.Scopes, // User permissions + Expiration: time.Now().Add(24 * time.Hour), // 24 hour expiration + }, nil +} + +// MCP Tool Arguments +type getUserInfoArgs struct { + UserID string `json:"user_id" jsonschema:"the user ID to get information for"` +} + +type createResourceArgs struct { + Name string `json:"name" jsonschema:"the name of the resource"` + Description string `json:"description" jsonschema:"the description of the resource"` + Content string `json:"content" jsonschema:"the content of the resource"` +} + +// SayHi is a simple MCP tool that requires authentication +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args struct{}) (*mcp.CallToolResult, any, error) { + // Extract user information from request (v0.3.0+) + userInfo := req.Extra.TokenInfo + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Hello! You have scopes: %v", userInfo.Scopes)}, + }, + }, nil, nil +} + +// GetUserInfo is an MCP tool that requires read scope +func GetUserInfo(ctx context.Context, req *mcp.CallToolRequest, args getUserInfoArgs) (*mcp.CallToolResult, any, error) { + // Extract user information from request (v0.3.0+) + userInfo := req.Extra.TokenInfo + + // Check if user has read scope. + if !slices.Contains(userInfo.Scopes, "read") { + return nil, nil, fmt.Errorf("insufficient permissions: read scope required") + } + + userData := map[string]any{ + "requested_user_id": args.UserID, + "your_scopes": userInfo.Scopes, + "message": "User information retrieved successfully", + } + + userDataJSON, err := json.Marshal(userData) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal user data: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: string(userDataJSON)}, + }, + }, nil, nil +} + +// CreateResource is an MCP tool that requires write scope +func CreateResource(ctx context.Context, req *mcp.CallToolRequest, args createResourceArgs) (*mcp.CallToolResult, any, error) { + // Extract user information from request (v0.3.0+) + userInfo := req.Extra.TokenInfo + + // Check if user has write scope. + if !slices.Contains(userInfo.Scopes, "write") { + return nil, nil, fmt.Errorf("insufficient permissions: write scope required") + } + + resourceInfo := map[string]any{ + "name": args.Name, + "description": args.Description, + "content": args.Content, + "created_by": "authenticated_user", + "created_at": time.Now().Format(time.RFC3339), + } + + resourceInfoJSON, err := json.Marshal(resourceInfo) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal resource info: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: fmt.Sprintf("Resource created successfully: %s", string(resourceInfoJSON))}, + }, + }, nil, nil +} + +// createMCPServer creates an MCP server with authentication-aware tools +func createMCPServer() *mcp.Server { + server := mcp.NewServer(&mcp.Implementation{Name: "authenticated-mcp-server"}, nil) + + // Add tools that require authentication. + mcp.AddTool(server, &mcp.Tool{ + Name: "say_hi", + Description: "A simple greeting tool that requires authentication", + }, SayHi) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_user_info", + Description: "Get user information (requires read scope)", + }, GetUserInfo) + + mcp.AddTool(server, &mcp.Tool{ + Name: "create_resource", + Description: "Create a new resource (requires write scope)", + }, CreateResource) + + return server +} + +func main() { + flag.Parse() + + // Create the MCP server. + server := createMCPServer() + + // Create authentication middleware. + jwtAuth := auth.RequireBearerToken(verifyJWT, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, // Require "read" permission + }) + + apiKeyAuth := auth.RequireBearerToken(verifyAPIKey, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, // Require "read" permission + }) + + // Create HTTP handler with authentication. + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, nil) + + // Apply authentication middleware to the MCP handler. + authenticatedHandler := jwtAuth(handler) + apiKeyHandler := apiKeyAuth(handler) + + // Create router for different authentication methods. + http.HandleFunc("/mcp/jwt", authenticatedHandler.ServeHTTP) + http.HandleFunc("/mcp/apikey", apiKeyHandler.ServeHTTP) + + // Add utility endpoints for token generation. + http.HandleFunc("/generate-token", func(w http.ResponseWriter, r *http.Request) { + // Get user ID from query parameters (default: "test-user"). + userID := r.URL.Query().Get("user_id") + if userID == "" { + userID = "test-user" + } + + // Get scopes from query parameters (default: ["read", "write"]). + scopes := strings.Split(r.URL.Query().Get("scopes"), ",") + if len(scopes) == 1 && scopes[0] == "" { + scopes = []string{"read", "write"} + } + + // Get expiration time from query parameters (default: 1 hour). + expiresIn := 1 * time.Hour + if expStr := r.URL.Query().Get("expires_in"); expStr != "" { + if exp, err := time.ParseDuration(expStr); err == nil { + expiresIn = exp + } + } + + // Generate the JWT token. + token, err := generateToken(userID, scopes, expiresIn) + if err != nil { + http.Error(w, "Failed to generate token", http.StatusInternalServerError) + return + } + + // Return the generated token. + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "token": token, + "type": "Bearer", + }) + }) + + http.HandleFunc("/generate-api-key", func(w http.ResponseWriter, r *http.Request) { + // Generate a random API key using cryptographically secure random bytes. + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + http.Error(w, "Failed to generate random bytes", http.StatusInternalServerError) + return + } + apiKey := "sk-" + base64.URLEncoding.EncodeToString(bytes) + + // Get user ID from query parameters (default: "test-user"). + userID := r.URL.Query().Get("user_id") + if userID == "" { + userID = "test-user" + } + + // Get scopes from query parameters (default: ["read"]). + scopes := strings.Split(r.URL.Query().Get("scopes"), ",") + if len(scopes) == 1 && scopes[0] == "" { + scopes = []string{"read"} + } + + // Store the new API key in our in-memory storage. + // In production, this would be stored in a database. + apiKeys[apiKey] = &APIKey{ + Key: apiKey, + UserID: userID, + Scopes: scopes, + } + + // Return the generated API key. + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "api_key": apiKey, + "type": "Bearer", + }) + }) + + // Health check endpoint. + http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "healthy", + "time": time.Now().Format(time.RFC3339), + }) + }) + + // Start the HTTP server. + log.Println("Authenticated MCP Server") + log.Println("========================") + log.Println("Server starting on", *httpAddr) + log.Println() + log.Println("Available endpoints:") + log.Println(" GET /health - Health check (no auth)") + log.Println(" GET /generate-token - Generate JWT token") + log.Println(" POST /generate-api-key - Generate API key") + log.Println(" POST /mcp/jwt - MCP endpoint (JWT auth)") + log.Println(" POST /mcp/apikey - MCP endpoint (API key auth)") + log.Println() + log.Println("Available MCP Tools:") + log.Println(" - say_hi - Simple greeting (any auth)") + log.Println(" - get_user_info - Get user info (read scope)") + log.Println(" - create_resource - Create resource (write scope)") + log.Println() + log.Println("Example usage:") + log.Println(" # Generate a token") + log.Println(" curl 'http://localhost:8080/generate-token?user_id=alice&scopes=read,write'") + log.Println() + log.Println(" # Use MCP with JWT authentication") + log.Println(" curl -H 'Authorization: Bearer ' -H 'Content-Type: application/json' \\") + log.Println(" -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"say_hi\",\"arguments\":{}}}' \\") + log.Println(" http://localhost:8080/mcp/jwt") + log.Println() + log.Println(" # Generate an API key") + log.Println(" curl -X POST 'http://localhost:8080/generate-api-key?user_id=bob&scopes=read'") + log.Println() + log.Println(" # Use MCP with API key authentication") + log.Println(" curl -H 'Authorization: Bearer ' -H 'Content-Type: application/json' \\") + log.Println(" -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"get_user_info\",\"arguments\":{\"user_id\":\"test\"}}}' \\") + log.Println(" http://localhost:8080/mcp/apikey") + + log.Fatal(http.ListenAndServe(*httpAddr, nil)) +} diff --git a/examples/server/basic/main.go b/examples/server/basic/main.go new file mode 100644 index 00000000..54af6caa --- /dev/null +++ b/examples/server/basic/main.go @@ -0,0 +1,58 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "fmt" + "log" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type SayHiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil +} + +func main() { + ctx := context.Background() + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + log.Fatal(err) + } + + client := mcp.NewClient(&mcp.Implementation{Name: "client"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + log.Fatal(err) + } + + res, err := clientSession.CallTool(ctx, &mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "user"}, + }) + if err != nil { + log.Fatal(err) + } + fmt.Println(res.Content[0].(*mcp.TextContent).Text) + + clientSession.Close() + serverSession.Wait() + + // Output: Hi user +} diff --git a/examples/completion/main.go b/examples/server/completion/main.go similarity index 79% rename from examples/completion/main.go rename to examples/server/completion/main.go index a24299bc..5220b0ee 100644 --- a/examples/completion/main.go +++ b/examples/server/completion/main.go @@ -16,17 +16,18 @@ import ( // a CompletionHandler to an MCP Server's options. func main() { // Define your custom CompletionHandler logic. - myCompletionHandler := func(_ context.Context, _ *mcp.ServerSession, params *mcp.CompleteParams) (*mcp.CompleteResult, error) { + // !+completionhandler + myCompletionHandler := func(_ context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { // In a real application, you'd implement actual completion logic here. // For this example, we return a fixed set of suggestions. var suggestions []string - switch params.Ref.Type { + switch req.Params.Ref.Type { case "ref/prompt": suggestions = []string{"suggestion1", "suggestion2", "suggestion3"} case "ref/resource": suggestions = []string{"suggestion4", "suggestion5", "suggestion6"} default: - return nil, fmt.Errorf("unrecognized content type %s", params.Ref.Type) + return nil, fmt.Errorf("unrecognized content type %s", req.Params.Ref.Type) } return &mcp.CompleteResult{ @@ -40,9 +41,10 @@ func main() { // Create the MCP Server instance and assign the handler. // No server running, just showing the configuration. - _ = mcp.NewServer("myServer", "v1.0.0", &mcp.ServerOptions{ + _ = mcp.NewServer(&mcp.Implementation{Name: "server"}, &mcp.ServerOptions{ CompletionHandler: myCompletionHandler, }) + // !-completionhandler log.Println("MCP Server instance created with a CompletionHandler assigned (but not running).") log.Println("This example demonstrates configuration, not live interaction.") diff --git a/examples/server/custom-transport/main.go b/examples/server/custom-transport/main.go new file mode 100644 index 00000000..c367cb62 --- /dev/null +++ b/examples/server/custom-transport/main.go @@ -0,0 +1,109 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bufio" + "context" + "errors" + "io" + "log" + "os" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// IOTransport is a simplified implementation of a transport that communicates using +// newline-delimited JSON over an io.Reader and io.Writer. It is similar to ioTransport +// in transport.go and serves as a demonstration of how to implement a custom transport. +type IOTransport struct { + r *bufio.Reader + w io.Writer +} + +// NewIOTransport creates a new IOTransport with the given io.Reader and io.Writer. +func NewIOTransport(r io.Reader, w io.Writer) *IOTransport { + return &IOTransport{ + r: bufio.NewReader(r), + w: w, + } +} + +// ioConn is a connection that uses newlines to delimit messages. It implements [mcp.Connection]. +type ioConn struct { + r *bufio.Reader + w io.Writer +} + +// Connect implements [mcp.Transport.Connect] by creating a new ioConn. +func (t *IOTransport) Connect(ctx context.Context) (mcp.Connection, error) { + return &ioConn{ + r: t.r, + w: t.w, + }, nil +} + +// Read implements [mcp.Connection.Read], assuming messages are newline-delimited JSON. +func (t *ioConn) Read(context.Context) (jsonrpc.Message, error) { + data, err := t.r.ReadBytes('\n') + if err != nil { + return nil, err + } + + return jsonrpc.DecodeMessage(data[:len(data)-1]) +} + +// Write implements [mcp.Connection.Write], appending a newline delimiter after the message. +func (t *ioConn) Write(_ context.Context, msg jsonrpc.Message) error { + data, err := jsonrpc.EncodeMessage(msg) + if err != nil { + return err + } + + _, err1 := t.w.Write(data) + _, err2 := t.w.Write([]byte{'\n'}) + return errors.Join(err1, err2) +} + +// Close implements [mcp.Connection.Close]. Since this is a simplified example, it is a no-op. +func (t *ioConn) Close() error { + return nil +} + +// SessionID implements [mcp.Connection.SessionID]. Since this is a simplified example, +// it returns an empty session ID. +func (t *ioConn) SessionID() string { + return "" +} + +// HiArgs is the argument type for the SayHi tool. +type HiArgs struct { + Name string `json:"name" mcp:"the name to say hi to"` +} + +// SayHi is a tool handler that responds with a greeting. +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.CallToolResult, struct{}, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, struct{}{}, nil +} + +func main() { + server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + + // Run the server with a custom IOTransport using stdio as the io.Reader and io.Writer. + transport := &IOTransport{ + r: bufio.NewReader(os.Stdin), + w: os.Stdout, + } + err := server.Run(context.Background(), transport) + if err != nil { + log.Println("[ERROR]: Failed to run server:", err) + } +} diff --git a/examples/server/distributed/main.go b/examples/server/distributed/main.go new file mode 100644 index 00000000..b1440402 --- /dev/null +++ b/examples/server/distributed/main.go @@ -0,0 +1,170 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The distributed command is an example of a distributed MCP server. +// +// It forks multiple child processes (according to the -child_ports flag), each +// of which is a streamable HTTP MCP server with the 'inc' tool, and proxies +// incoming http requests to them. +// +// Distributed MCP servers must be stateless, because there's no guarantee that +// subsequent requests for a session land on the same backend. However, they +// may still have logical session IDs, as can be seen with verbose logging +// (-v). +// +// Example: +// +// ./distributed -http=localhost:8080 -child_ports=8081,8082 +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" + "os" + "os/exec" + "os/signal" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const childPortVar = "MCP_CHILD_PORT" + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + childPorts = flag.String("child_ports", "", "comma-separated child ports to distribute to") + verbose = flag.Bool("v", false, "if set, enable verbose logging") +) + +func main() { + // This command runs as either a parent or a child, depending on whether + // childPortVar is set (a.k.a. the fork-and-exec trick). + // + // Each child is a streamable HTTP server, and the parent is a reverse proxy. + flag.Parse() + if v := os.Getenv(childPortVar); v != "" { + child(v) + } else { + parent() + } +} + +func parent() { + exe, err := os.Executable() + if err != nil { + log.Fatal(err) + } + + if *httpAddr == "" { + log.Fatal("must provide -http") + } + if *childPorts == "" { + log.Fatal("must provide -child_ports") + } + + // Ensure that children are cleaned up on CTRL-C + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + // Start the child processes. + ports := strings.Split(*childPorts, ",") + var wg sync.WaitGroup + childURLs := make([]*url.URL, len(ports)) + for i, port := range ports { + childURL := fmt.Sprintf("http://localhost:%s", port) + childURLs[i], err = url.Parse(childURL) + if err != nil { + log.Fatal(err) + } + cmd := exec.CommandContext(ctx, exe, os.Args[1:]...) + cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", childPortVar, port)) + cmd.Stderr = os.Stderr + + wg.Add(1) + go func() { + defer wg.Done() + log.Printf("starting child %d at %s", i, childURL) + if err := cmd.Run(); err != nil && ctx.Err() == nil { + log.Printf("child %d failed: %v", i, err) + } else { + log.Printf("child %d exited gracefully", i) + } + }() + } + + // Start a reverse proxy that round-robin's requests to each backend. + var nextBackend atomic.Int64 + server := http.Server{ + Addr: *httpAddr, + Handler: &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + child := int(nextBackend.Add(1)) % len(childURLs) + if *verbose { + log.Printf("dispatching to localhost:%s", ports[child]) + } + r.SetURL(childURLs[child]) + }, + }, + } + + wg.Add(1) + go func() { + defer wg.Done() + if err := server.ListenAndServe(); err != nil && ctx.Err() == nil { + log.Printf("Server failed: %v", err) + } + }() + + log.Printf("Serving at %s (CTRL-C to cancel)", *httpAddr) + + <-ctx.Done() + stop() // restore the interrupt signal + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Attempt the graceful shutdown. + if err := server.Shutdown(shutdownCtx); err != nil { + log.Fatalf("Server shutdown failed: %v", err) + } + + // Wait for the subprocesses and http server to stop. + wg.Wait() + + log.Println("Server shutdown gracefully.") +} + +func child(port string) { + // Create a server with a single tool that increments a counter and sends a notification. + server := mcp.NewServer(&mcp.Implementation{Name: "counter"}, nil) + + var count atomic.Int64 + inc := func(ctx context.Context, req *mcp.CallToolRequest, args struct{}) (*mcp.CallToolResult, struct{ Count int64 }, error) { + n := count.Add(1) + if *verbose { + log.Printf("request %d (session %s)", n, req.Session.ID()) + } + // Send a notification in the context of the request + // Hint: in stateless mode, at least log level 'info' is required to send notifications + req.Session.Log(ctx, &mcp.LoggingMessageParams{Data: fmt.Sprintf("request %d (session %s)", n, req.Session.ID()), Level: "info"}) + return nil, struct{ Count int64 }{n}, nil + } + mcp.AddTool(server, &mcp.Tool{Name: "inc"}, inc) + + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{ + Stateless: true, + }) + log.Printf("child listening on localhost:%s", port) + log.Fatal(http.ListenAndServe(fmt.Sprintf("localhost:%s", port), handler)) +} diff --git a/examples/server/elicitation/main.go b/examples/server/elicitation/main.go new file mode 100644 index 00000000..59bc25cf --- /dev/null +++ b/examples/server/elicitation/main.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func main() { + ctx := context.Background() + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + // Create server + server := mcp.NewServer(&mcp.Implementation{Name: "config-server", Version: "v1.0.0"}, nil) + + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + log.Fatal(err) + } + + // Create client with elicitation handler + // Note: Never use elicitation for sensitive data like API keys or passwords + client := mcp.NewClient(&mcp.Implementation{Name: "config-client", Version: "v1.0.0"}, &mcp.ClientOptions{ + ElicitationHandler: func(ctx context.Context, request *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + fmt.Printf("Server requests: %s\n", request.Params.Message) + + // In a real application, this would prompt the user for input + // Here we simulate user providing configuration data + return &mcp.ElicitResult{ + Action: "accept", + Content: map[string]any{ + "serverEndpoint": "https://api.example.com", + "maxRetries": float64(3), + "enableLogs": true, + }, + }, nil + }, + }) + + _, err = client.Connect(ctx, clientTransport, nil) + if err != nil { + log.Fatal(err) + } + + // Server requests user configuration via elicitation + configSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "serverEndpoint": {Type: "string", Description: "Server endpoint URL"}, + "maxRetries": {Type: "number", Minimum: ptr(1.0), Maximum: ptr(10.0)}, + "enableLogs": {Type: "boolean", Description: "Enable debug logging"}, + }, + Required: []string{"serverEndpoint"}, + } + + result, err := serverSession.Elicit(ctx, &mcp.ElicitParams{ + Message: "Please provide your configuration settings", + RequestedSchema: configSchema, + }) + if err != nil { + log.Fatal(err) + } + + if result.Action == "accept" { + fmt.Printf("Configuration received: Endpoint: %v, Max Retries: %.0f, Logs: %v\n", + result.Content["serverEndpoint"], + result.Content["maxRetries"], + result.Content["enableLogs"]) + } + + // Output: + // Server requests: Please provide your configuration settings + // Configuration received: Endpoint: https://api.example.com, Max Retries: 3, Logs: true +} + +// ptr is a helper function to create pointers for schema constraints +func ptr[T any](v T) *T { + return &v +} diff --git a/examples/server/everything/main.go b/examples/server/everything/main.go new file mode 100644 index 00000000..0b81919f --- /dev/null +++ b/examples/server/everything/main.go @@ -0,0 +1,224 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The everything server implements all supported features of an MCP server. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + _ "net/http/pprof" + "net/url" + "os" + "runtime" + "strings" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + pprofAddr = flag.String("pprof", "", "if set, host the pprof debugging server at this address") +) + +func main() { + flag.Parse() + + if *pprofAddr != "" { + // For debugging memory leaks, add an endpoint to trigger a few garbage + // collection cycles and ensure the /debug/pprof/heap endpoint only reports + // reachable memory. + http.DefaultServeMux.Handle("/gc", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + for range 3 { + runtime.GC() + } + fmt.Fprintln(w, "GC'ed") + })) + go func() { + // DefaultServeMux was mutated by the /debug/pprof import. + http.ListenAndServe(*pprofAddr, http.DefaultServeMux) + }() + } + + opts := &mcp.ServerOptions{ + Instructions: "Use this server!", + CompletionHandler: complete, // support completions by setting this handler + } + + server := mcp.NewServer(&mcp.Implementation{Name: "everything"}, opts) + + // Add tools that exercise different features of the protocol. + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, contentTool) + mcp.AddTool(server, &mcp.Tool{Name: "greet (structured)"}, structuredTool) // returns structured output + mcp.AddTool(server, &mcp.Tool{Name: "ping"}, pingingTool) // performs a ping + mcp.AddTool(server, &mcp.Tool{Name: "log"}, loggingTool) // performs a log + mcp.AddTool(server, &mcp.Tool{Name: "sample"}, samplingTool) // performs sampling + mcp.AddTool(server, &mcp.Tool{Name: "elicit"}, elicitingTool) // performs elicitation + mcp.AddTool(server, &mcp.Tool{Name: "roots"}, rootsTool) // lists roots + + // Add a basic prompt. + server.AddPrompt(&mcp.Prompt{Name: "greet"}, prompt) + + // Add an embedded resource. + server.AddResource(&mcp.Resource{ + Name: "info", + MIMEType: "text/plain", + URI: "embedded:info", + }, embeddedResource) + + // Serve over stdio, or streamable HTTP if -http is set. + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("MCP handler listening at %s", *httpAddr) + if *pprofAddr != "" { + log.Printf("pprof listening at http://%s/debug/pprof", *pprofAddr) + } + log.Fatal(http.ListenAndServe(*httpAddr, handler)) + } else { + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} + +func prompt(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Hi prompt", + Messages: []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}, + }, + }, + }, nil +} + +var embeddedResources = map[string]string{ + "info": "This is the hello example server.", +} + +func embeddedResource(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + u, err := url.Parse(req.Params.URI) + if err != nil { + return nil, err + } + if u.Scheme != "embedded" { + return nil, fmt.Errorf("wrong scheme: %q", u.Scheme) + } + key := u.Opaque + text, ok := embeddedResources[key] + if !ok { + return nil, fmt.Errorf("no embedded resource named %q", key) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + {URI: req.Params.URI, MIMEType: "text/plain", Text: text}, + }, + }, nil +} + +type args struct { + Name string `json:"name" jsonschema:"the name to say hi to"` +} + +// contentTool is a tool that returns unstructured content. +// +// Since its output type is 'any', no output schema is created. +func contentTool(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil +} + +type result struct { + Message string `json:"message" jsonschema:"the message to convey"` +} + +// structuredTool returns a structured result. +func structuredTool(ctx context.Context, req *mcp.CallToolRequest, args *args) (*mcp.CallToolResult, *result, error) { + return nil, &result{Message: "Hi " + args.Name}, nil +} + +func pingingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if err := req.Session.Ping(ctx, nil); err != nil { + return nil, nil, fmt.Errorf("ping failed") + } + return nil, nil, nil +} + +func loggingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if err := req.Session.Log(ctx, &mcp.LoggingMessageParams{ + Data: "something happened!", + Level: "error", + }); err != nil { + return nil, nil, fmt.Errorf("log failed") + } + return nil, nil, nil +} + +func rootsTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + res, err := req.Session.ListRoots(ctx, nil) + if err != nil { + return nil, nil, fmt.Errorf("listing roots failed: %v", err) + } + var allroots []string + for _, r := range res.Roots { + allroots = append(allroots, fmt.Sprintf("%s:%s", r.Name, r.URI)) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: strings.Join(allroots, ",")}, + }, + }, nil, nil +} + +func samplingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + res, err := req.Session.CreateMessage(ctx, new(mcp.CreateMessageParams)) + if err != nil { + return nil, nil, fmt.Errorf("sampling failed: %v", err) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + res.Content, + }, + }, nil, nil +} + +func elicitingTool(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + res, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ + Message: "provide a random string", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "random": {Type: "string"}, + }, + }, + }) + if err != nil { + return nil, nil, fmt.Errorf("eliciting failed: %v", err) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: res.Content["random"].(string)}, + }, + }, nil, nil +} + +func complete(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return &mcp.CompleteResult{ + Completion: mcp.CompletionResultDetails{ + Total: 1, + Values: []string{req.Params.Argument.Value + "x"}, + }, + }, nil +} diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go new file mode 100644 index 00000000..796feff8 --- /dev/null +++ b/examples/server/hello/main.go @@ -0,0 +1,46 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The hello server contains a single tool that says hi to the user. +// +// It runs over the stdio transport. +package main + +import ( + "context" + "log" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func main() { + // Create a server with a single tool that says "Hi". + server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) + + // Using the generic AddTool automatically populates the the input and output + // schema of the tool. + // + // The schema considers 'json' and 'jsonschema' struct tags to get argument + // names and descriptions. + type args struct { + Name string `json:"name" jsonschema:"the person to greet"` + } + mcp.AddTool(server, &mcp.Tool{ + Name: "greet", + Description: "say hi", + }, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil + }) + + // server.Run runs the server on the given transport. + // + // In this case, the server communicates over stdin/stdout. + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Printf("Server failed: %v", err) + } +} diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go new file mode 100644 index 00000000..ad83ca0b --- /dev/null +++ b/examples/server/memory/kb.go @@ -0,0 +1,567 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Entity represents a knowledge graph node with observations. +type Entity struct { + Name string `json:"name"` + EntityType string `json:"entityType"` + Observations []string `json:"observations"` +} + +// Relation represents a directed edge between two entities. +type Relation struct { + From string `json:"from"` + To string `json:"to"` + RelationType string `json:"relationType"` +} + +// Observation contains facts about an entity. +type Observation struct { + EntityName string `json:"entityName"` + Contents []string `json:"contents"` + + Observations []string `json:"observations,omitempty"` // Used for deletion operations +} + +// KnowledgeGraph represents the complete graph structure. +type KnowledgeGraph struct { + Entities []Entity `json:"entities"` + Relations []Relation `json:"relations"` +} + +// store provides persistence interface for knowledge base data. +type store interface { + Read() ([]byte, error) + Write(data []byte) error +} + +// memoryStore implements in-memory storage that doesn't persist across restarts. +type memoryStore struct { + data []byte +} + +// Read returns the in-memory data. +func (ms *memoryStore) Read() ([]byte, error) { + return ms.data, nil +} + +// Write stores data in memory. +func (ms *memoryStore) Write(data []byte) error { + ms.data = data + return nil +} + +// fileStore implements file-based storage for persistent knowledge base. +type fileStore struct { + path string +} + +// Read loads data from file, returning empty slice if file doesn't exist. +func (fs *fileStore) Read() ([]byte, error) { + data, err := os.ReadFile(fs.path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to read file %s: %w", fs.path, err) + } + return data, nil +} + +// Write saves data to file with 0600 permissions. +func (fs *fileStore) Write(data []byte) error { + if err := os.WriteFile(fs.path, data, 0o600); err != nil { + return fmt.Errorf("failed to write file %s: %w", fs.path, err) + } + return nil +} + +// knowledgeBase manages entities and relations with persistent storage. +type knowledgeBase struct { + s store +} + +// kbItem represents a single item in persistent storage (entity or relation). +type kbItem struct { + Type string `json:"type"` + + // Entity fields (when Type == "entity") + Name string `json:"name,omitempty"` + EntityType string `json:"entityType,omitempty"` + Observations []string `json:"observations,omitempty"` + + // Relation fields (when Type == "relation") + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + RelationType string `json:"relationType,omitempty"` +} + +// loadGraph deserializes the knowledge graph from storage. +func (k knowledgeBase) loadGraph() (KnowledgeGraph, error) { + data, err := k.s.Read() + if err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to read from store: %w", err) + } + + if len(data) == 0 { + return KnowledgeGraph{}, nil + } + + var items []kbItem + if err := json.Unmarshal(data, &items); err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to unmarshal from store: %w", err) + } + + graph := KnowledgeGraph{} + + for _, item := range items { + switch item.Type { + case "entity": + graph.Entities = append(graph.Entities, Entity{ + Name: item.Name, + EntityType: item.EntityType, + Observations: item.Observations, + }) + case "relation": + graph.Relations = append(graph.Relations, Relation{ + From: item.From, + To: item.To, + RelationType: item.RelationType, + }) + } + } + + return graph, nil +} + +// saveGraph serializes and persists the knowledge graph to storage. +func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error { + items := make([]kbItem, 0, len(graph.Entities)+len(graph.Relations)) + + for _, entity := range graph.Entities { + items = append(items, kbItem{ + Type: "entity", + Name: entity.Name, + EntityType: entity.EntityType, + Observations: entity.Observations, + }) + } + + for _, relation := range graph.Relations { + items = append(items, kbItem{ + Type: "relation", + From: relation.From, + To: relation.To, + RelationType: relation.RelationType, + }) + } + + itemsJSON, err := json.Marshal(items) + if err != nil { + return fmt.Errorf("failed to marshal items: %w", err) + } + + if err := k.s.Write(itemsJSON); err != nil { + return fmt.Errorf("failed to write to store: %w", err) + } + return nil +} + +// createEntities adds new entities to the graph, skipping duplicates by name. +// It returns the new entities that were actually added. +func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newEntities []Entity + for _, entity := range entities { + if !slices.ContainsFunc(graph.Entities, func(e Entity) bool { return e.Name == entity.Name }) { + newEntities = append(newEntities, entity) + graph.Entities = append(graph.Entities, entity) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newEntities, nil +} + +// createRelations adds new relations to the graph, skipping exact duplicates. +// It returns the new relations that were actually added. +func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newRelations []Relation + for _, relation := range relations { + exists := slices.ContainsFunc(graph.Relations, func(r Relation) bool { + return r.From == relation.From && + r.To == relation.To && + r.RelationType == relation.RelationType + }) + if !exists { + newRelations = append(newRelations, relation) + graph.Relations = append(graph.Relations, relation) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newRelations, nil +} + +// addObservations appends new observations to existing entities. +// It returns the new observations that were actually added. +func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var results []Observation + + for _, obs := range observations { + entityIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { return e.Name == obs.EntityName }) + if entityIndex == -1 { + return nil, fmt.Errorf("entity with name %s not found", obs.EntityName) + } + + var newObservations []string + for _, content := range obs.Contents { + if !slices.Contains(graph.Entities[entityIndex].Observations, content) { + newObservations = append(newObservations, content) + graph.Entities[entityIndex].Observations = append(graph.Entities[entityIndex].Observations, content) + } + } + + results = append(results, Observation{ + EntityName: obs.EntityName, + Contents: newObservations, + }) + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return results, nil +} + +// deleteEntities removes entities and their associated relations. +func (k knowledgeBase) deleteEntities(entityNames []string) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + // Create map for quick lookup + entitiesToDelete := make(map[string]bool) + for _, name := range entityNames { + entitiesToDelete[name] = true + } + + // Filter entities using slices.DeleteFunc + graph.Entities = slices.DeleteFunc(graph.Entities, func(entity Entity) bool { + return entitiesToDelete[entity.Name] + }) + + // Filter relations using slices.DeleteFunc + graph.Relations = slices.DeleteFunc(graph.Relations, func(relation Relation) bool { + return entitiesToDelete[relation.From] || entitiesToDelete[relation.To] + }) + + return k.saveGraph(graph) +} + +// deleteObservations removes specific observations from entities. +func (k knowledgeBase) deleteObservations(deletions []Observation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + for _, deletion := range deletions { + entityIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { + return e.Name == deletion.EntityName + }) + if entityIndex == -1 { + continue + } + + // Create a map for quick lookup + observationsToDelete := make(map[string]bool) + for _, observation := range deletion.Observations { + observationsToDelete[observation] = true + } + + // Filter observations using slices.DeleteFunc + graph.Entities[entityIndex].Observations = slices.DeleteFunc(graph.Entities[entityIndex].Observations, func(observation string) bool { + return observationsToDelete[observation] + }) + } + + return k.saveGraph(graph) +} + +// deleteRelations removes specific relations from the graph. +func (k knowledgeBase) deleteRelations(relations []Relation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + // Filter relations using slices.DeleteFunc and slices.ContainsFunc + graph.Relations = slices.DeleteFunc(graph.Relations, func(existingRelation Relation) bool { + return slices.ContainsFunc(relations, func(relationToDelete Relation) bool { + return existingRelation.From == relationToDelete.From && + existingRelation.To == relationToDelete.To && + existingRelation.RelationType == relationToDelete.RelationType + }) + }) + return k.saveGraph(graph) +} + +// searchNodes filters entities and relations matching the query string. +func (k knowledgeBase) searchNodes(query string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + queryLower := strings.ToLower(query) + var filteredEntities []Entity + + // Filter entities + for _, entity := range graph.Entities { + if strings.Contains(strings.ToLower(entity.Name), queryLower) || + strings.Contains(strings.ToLower(entity.EntityType), queryLower) { + filteredEntities = append(filteredEntities, entity) + continue + } + + // Check observations + for _, observation := range entity.Observations { + if strings.Contains(strings.ToLower(observation), queryLower) { + filteredEntities = append(filteredEntities, entity) + break + } + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +// openNodes returns entities with specified names and their interconnecting relations. +func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + // Create map for quick name lookup + nameSet := make(map[string]bool) + for _, name := range names { + nameSet[name] = true + } + + // Filter entities + var filteredEntities []Entity + for _, entity := range graph.Entities { + if nameSet[entity.Name] { + filteredEntities = append(filteredEntities, entity) + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.CallToolRequest, args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { + var res mcp.CallToolResult + + entities, err := k.createEntities(args.Entities) + if err != nil { + return nil, CreateEntitiesResult{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Entities created successfully"}, + } + return &res, CreateEntitiesResult{Entities: entities}, nil +} + +func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.CallToolRequest, args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { + var res mcp.CallToolResult + + relations, err := k.createRelations(args.Relations) + if err != nil { + return nil, CreateRelationsResult{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Relations created successfully"}, + } + + return &res, CreateRelationsResult{Relations: relations}, nil +} + +func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.CallToolRequest, args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { + var res mcp.CallToolResult + + observations, err := k.addObservations(args.Observations) + if err != nil { + return nil, AddObservationsResult{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Observations added successfully"}, + } + + return &res, AddObservationsResult{ + Observations: observations, + }, nil +} + +func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.CallToolRequest, args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { + var res mcp.CallToolResult + + err := k.deleteEntities(args.EntityNames) + if err != nil { + return nil, nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Entities deleted successfully"}, + } + + return &res, nil, nil +} + +func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.CallToolRequest, args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { + var res mcp.CallToolResult + + err := k.deleteObservations(args.Deletions) + if err != nil { + return nil, nil, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Observations deleted successfully"}, + } + + return &res, nil, nil +} + +func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.CallToolRequest, args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { + var res mcp.CallToolResult + + err := k.deleteRelations(args.Relations) + if err != nil { + return nil, struct{}{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Relations deleted successfully"}, + } + + return &res, struct{}{}, nil +} + +func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.CallToolRequest, args any) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult + + graph, err := k.loadGraph() + if err != nil { + return nil, KnowledgeGraph{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Graph read successfully"}, + } + + return &res, graph, nil +} + +func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.CallToolRequest, args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult + + graph, err := k.searchNodes(args.Query) + if err != nil { + return nil, KnowledgeGraph{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Nodes searched successfully"}, + } + + return &res, graph, nil +} + +func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.CallToolRequest, args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult + + graph, err := k.openNodes(args.Names) + if err != nil { + return nil, KnowledgeGraph{}, err + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Nodes opened successfully"}, + } + return &res, graph, nil +} diff --git a/examples/server/memory/kb_test.go b/examples/server/memory/kb_test.go new file mode 100644 index 00000000..5d40ae64 --- /dev/null +++ b/examples/server/memory/kb_test.go @@ -0,0 +1,649 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "slices" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// stores provides test factories for both storage implementations. +func stores() map[string]func(t *testing.T) store { + return map[string]func(t *testing.T) store{ + "file": func(t *testing.T) store { + tempDir, err := os.MkdirTemp("", "kb-test-file-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + t.Cleanup(func() { os.RemoveAll(tempDir) }) + return &fileStore{path: filepath.Join(tempDir, "test-memory.json")} + }, + "memory": func(t *testing.T) store { + return &memoryStore{} + }, + } +} + +// TestKnowledgeBaseOperations verifies CRUD operations work correctly. +func TestKnowledgeBaseOperations(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Verify empty graph loads correctly + graph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load empty graph: %v", err) + } + if len(graph.Entities) != 0 || len(graph.Relations) != 0 { + t.Errorf("expected empty graph, got %+v", graph) + } + + // Create and verify entities + testEntities := []Entity{ + { + Name: "Alice", + EntityType: "Person", + Observations: []string{"Likes coffee"}, + }, + { + Name: "Bob", + EntityType: "Person", + Observations: []string{"Likes tea"}, + }, + } + + createdEntities, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create entities: %v", err) + } + if len(createdEntities) != 2 { + t.Errorf("expected 2 created entities, got %d", len(createdEntities)) + } + + // Verify entities persist + graph, err = kb.loadGraph() + if err != nil { + t.Fatalf("failed to read graph: %v", err) + } + if len(graph.Entities) != 2 { + t.Errorf("expected 2 entities, got %d", len(graph.Entities)) + } + + // Create and verify relations + testRelations := []Relation{ + { + From: "Alice", + To: "Bob", + RelationType: "friend", + }, + } + + createdRelations, err := kb.createRelations(testRelations) + if err != nil { + t.Fatalf("failed to create relations: %v", err) + } + if len(createdRelations) != 1 { + t.Errorf("expected 1 created relation, got %d", len(createdRelations)) + } + + // Add observations to entities + testObservations := []Observation{ + { + EntityName: "Alice", + Contents: []string{"Works as developer", "Lives in New York"}, + }, + } + + addedObservations, err := kb.addObservations(testObservations) + if err != nil { + t.Fatalf("failed to add observations: %v", err) + } + if len(addedObservations) != 1 || len(addedObservations[0].Contents) != 2 { + t.Errorf("expected 1 observation with 2 contents, got %+v", addedObservations) + } + + // Search nodes by content + searchResult, err := kb.searchNodes("developer") + if err != nil { + t.Fatalf("failed to search nodes: %v", err) + } + if len(searchResult.Entities) != 1 || searchResult.Entities[0].Name != "Alice" { + t.Errorf("expected to find Alice when searching for 'developer', got %+v", searchResult) + } + + // Retrieve specific nodes + openResult, err := kb.openNodes([]string{"Bob"}) + if err != nil { + t.Fatalf("failed to open nodes: %v", err) + } + if len(openResult.Entities) != 1 || openResult.Entities[0].Name != "Bob" { + t.Errorf("expected to find Bob when opening 'Bob', got %+v", openResult) + } + + // Remove specific observations + deleteObs := []Observation{ + { + EntityName: "Alice", + Observations: []string{"Works as developer"}, + }, + } + err = kb.deleteObservations(deleteObs) + if err != nil { + t.Fatalf("failed to delete observations: %v", err) + } + + // Confirm observation removal + graph, _ = kb.loadGraph() + aliceIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { + return e.Name == "Alice" + }) + if aliceIndex == -1 { + t.Errorf("entity 'Alice' not found after deleting observation") + } else { + alice := graph.Entities[aliceIndex] + if slices.Contains(alice.Observations, "Works as developer") { + t.Errorf("observation 'Works as developer' should have been deleted") + } + } + + // Remove relations + err = kb.deleteRelations(testRelations) + if err != nil { + t.Fatalf("failed to delete relations: %v", err) + } + + // Confirm relation removal + graph, _ = kb.loadGraph() + if len(graph.Relations) != 0 { + t.Errorf("expected 0 relations after deletion, got %d", len(graph.Relations)) + } + + // Remove entities + err = kb.deleteEntities([]string{"Alice"}) + if err != nil { + t.Fatalf("failed to delete entities: %v", err) + } + + // Confirm entity removal + graph, _ = kb.loadGraph() + if len(graph.Entities) != 1 || graph.Entities[0].Name != "Bob" { + t.Errorf("expected only Bob to remain after deleting Alice, got %+v", graph.Entities) + } + }) + } +} + +// TestSaveAndLoadGraph ensures data persists correctly across save/load cycles. +func TestSaveAndLoadGraph(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup test data + testGraph := KnowledgeGraph{ + Entities: []Entity{ + { + Name: "Charlie", + EntityType: "Person", + Observations: []string{"Likes hiking"}, + }, + }, + Relations: []Relation{ + { + From: "Charlie", + To: "Mountains", + RelationType: "enjoys", + }, + }, + } + + // Persist to storage + err := kb.saveGraph(testGraph) + if err != nil { + t.Fatalf("failed to save graph: %v", err) + } + + // Reload from storage + loadedGraph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load graph: %v", err) + } + + // Verify data integrity + if !reflect.DeepEqual(testGraph, loadedGraph) { + t.Errorf("loaded graph does not match saved graph.\nExpected: %+v\nGot: %+v", testGraph, loadedGraph) + } + + // Test malformed data handling + if fs, ok := s.(*fileStore); ok { + err := os.WriteFile(fs.path, []byte("invalid json"), 0o600) + if err != nil { + t.Fatalf("failed to write invalid json: %v", err) + } + + _, err = kb.loadGraph() + if err == nil { + t.Errorf("expected error when loading invalid JSON, got nil") + } + } + }) + } +} + +// TestDuplicateEntitiesAndRelations verifies duplicate prevention logic. +func TestDuplicateEntitiesAndRelations(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup initial state + initialEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Plays guitar"}, + }, + } + + _, err := kb.createEntities(initialEntities) + if err != nil { + t.Fatalf("failed to create initial entities: %v", err) + } + + // Attempt duplicate creation + duplicateEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Sings well"}, + }, + { + Name: "Eve", + EntityType: "Person", + Observations: []string{"Plays piano"}, + }, + } + + newEntities, err := kb.createEntities(duplicateEntities) + if err != nil { + t.Fatalf("failed when adding duplicate entities: %v", err) + } + + // Verify only new entities created + if len(newEntities) != 1 || newEntities[0].Name != "Eve" { + t.Errorf("expected only 'Eve' to be created, got %+v", newEntities) + } + + // Setup initial relation + initialRelation := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + } + + _, err = kb.createRelations(initialRelation) + if err != nil { + t.Fatalf("failed to create initial relation: %v", err) + } + + // Test relation deduplication + duplicateRelations := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + { + From: "Eve", + To: "Dave", + RelationType: "friend", + }, + } + + newRelations, err := kb.createRelations(duplicateRelations) + if err != nil { + t.Fatalf("failed when adding duplicate relations: %v", err) + } + + // Verify only new relations created + if len(newRelations) != 1 || newRelations[0].From != "Eve" || newRelations[0].To != "Dave" { + t.Errorf("expected only 'Eve->Dave' relation to be created, got %+v", newRelations) + } + }) + } +} + +// TestErrorHandling verifies proper error responses for invalid operations. +func TestErrorHandling(t *testing.T) { + t.Run("FileStoreWriteError", func(t *testing.T) { + // Test file write to invalid path + kb := knowledgeBase{ + s: &fileStore{path: filepath.Join("nonexistent", "directory", "file.json")}, + } + + testEntities := []Entity{ + {Name: "TestEntity"}, + } + + _, err := kb.createEntities(testEntities) + if err == nil { + t.Errorf("expected error when writing to non-existent directory, got nil") + } + }) + + for name, newStore := range stores() { + t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup valid entity for comparison + _, err := kb.createEntities([]Entity{{Name: "RealEntity"}}) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Test invalid entity reference + nonExistentObs := []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This shouldn't work"}, + }, + } + + _, err = kb.addObservations(nonExistentObs) + if err == nil { + t.Errorf("expected error when adding observations to non-existent entity, got nil") + } + }) + } +} + +// TestFileFormatting verifies the JSON storage format structure. +func TestFileFormatting(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Setup test entity + testEntities := []Entity{ + { + Name: "FileTest", + EntityType: "TestEntity", + Observations: []string{"Test observation"}, + }, + } + + _, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Extract raw storage data + data, err := s.Read() + if err != nil { + t.Fatalf("failed to read from store: %v", err) + } + + // Validate JSON format + var items []kbItem + err = json.Unmarshal(data, &items) + if err != nil { + t.Fatalf("failed to parse store data JSON: %v", err) + } + + // Check data structure + if len(items) != 1 { + t.Fatalf("expected 1 item in memory file, got %d", len(items)) + } + + item := items[0] + if item.Type != "entity" || + item.Name != "FileTest" || + item.EntityType != "TestEntity" || + len(item.Observations) != 1 || + item.Observations[0] != "Test observation" { + t.Errorf("store item format incorrect: %+v", item) + } + }) + } +} + +// TestMCPServerIntegration tests the knowledge base through MCP server layer. +func TestMCPServerIntegration(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + // Create mock server session + ctx := context.Background() + + createResult, out, err := kb.CreateEntities(ctx, nil, CreateEntitiesArgs{ + Entities: []Entity{ + { + Name: "TestPerson", + EntityType: "Person", + Observations: []string{"Likes testing"}, + }, + }, + }) + if err != nil { + t.Fatalf("MCP CreateEntities failed: %v", err) + } + if createResult.IsError { + t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) + } + if len(out.Entities) != 1 { + t.Errorf("expected 1 entity created, got %d", len(out.Entities)) + } + + // Test ReadGraph through MCP + readResult, outg, err := kb.ReadGraph(ctx, nil, nil) + if err != nil { + t.Fatalf("MCP ReadGraph failed: %v", err) + } + if readResult.IsError { + t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) + } + if len(outg.Entities) != 1 { + t.Errorf("expected 1 entity in graph, got %d", len(outg.Entities)) + } + + relationsResult, outr, err := kb.CreateRelations(ctx, nil, CreateRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }) + if err != nil { + t.Fatalf("MCP CreateRelations failed: %v", err) + } + if relationsResult.IsError { + t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) + } + if len(outr.Relations) != 1 { + t.Errorf("expected 1 relation created, got %d", len(outr.Relations)) + } + + obsResult, outo, err := kb.AddObservations(ctx, nil, AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "TestPerson", + Contents: []string{"Works remotely", "Drinks coffee"}, + }, + }, + }) + if err != nil { + t.Fatalf("MCP AddObservations failed: %v", err) + } + if obsResult.IsError { + t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) + } + if len(outo.Observations) != 1 { + t.Errorf("expected 1 observation result, got %d", len(outo.Observations)) + } + + searchResult, outg, err := kb.SearchNodes(ctx, nil, SearchNodesArgs{ + Query: "coffee", + }) + if err != nil { + t.Fatalf("MCP SearchNodes failed: %v", err) + } + if searchResult.IsError { + t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) + } + if len(outg.Entities) != 1 { + t.Errorf("expected 1 entity from search, got %d", len(outg.Entities)) + } + + openResult, outg, err := kb.OpenNodes(ctx, nil, OpenNodesArgs{ + Names: []string{"TestPerson"}, + }) + if err != nil { + t.Fatalf("MCP OpenNodes failed: %v", err) + } + if openResult.IsError { + t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) + } + if len(outg.Entities) != 1 { + t.Errorf("expected 1 entity from open, got %d", len(outg.Entities)) + } + + deleteObsResult, _, err := kb.DeleteObservations(ctx, nil, DeleteObservationsArgs{ + Deletions: []Observation{ + { + EntityName: "TestPerson", + Observations: []string{"Works remotely"}, + }, + }, + }) + if err != nil { + t.Fatalf("MCP DeleteObservations failed: %v", err) + } + if deleteObsResult.IsError { + t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) + } + + deleteRelResult, _, err := kb.DeleteRelations(ctx, nil, DeleteRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }) + if err != nil { + t.Fatalf("MCP DeleteRelations failed: %v", err) + } + if deleteRelResult.IsError { + t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) + } + + deleteEntResult, _, err := kb.DeleteEntities(ctx, nil, DeleteEntitiesArgs{ + EntityNames: []string{"TestPerson"}, + }) + if err != nil { + t.Fatalf("MCP DeleteEntities failed: %v", err) + } + if deleteEntResult.IsError { + t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) + } + + // Verify final state + _, outg, err = kb.ReadGraph(ctx, nil, nil) + if err != nil { + t.Fatalf("Final MCP ReadGraph failed: %v", err) + } + if len(outg.Entities) != 0 { + t.Errorf("expected empty graph after deletion, got %d entities", len(outg.Entities)) + } + }) + } +} + +// TestMCPErrorHandling tests error scenarios through MCP layer. +func TestMCPErrorHandling(t *testing.T) { + for name, newStore := range stores() { + t.Run(name, func(t *testing.T) { + s := newStore(t) + kb := knowledgeBase{s: s} + + ctx := context.Background() + + _, _, err := kb.AddObservations(ctx, nil, AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This should fail"}, + }, + }, + }) + if err == nil { + t.Errorf("expected MCP AddObservations to return error for non-existent entity") + } else { + // Verify the error message contains expected text + want := "entity with name NonExistentEntity not found" + if !strings.Contains(err.Error(), want) { + t.Errorf("expected error message to contain '%s', got: %v", want, err) + } + } + }) + } +} + +// TestMCPResponseFormat verifies MCP response format consistency. +func TestMCPResponseFormat(t *testing.T) { + s := &memoryStore{} + kb := knowledgeBase{s: s} + + ctx := context.Background() + + result, out, err := kb.CreateEntities(ctx, nil, CreateEntitiesArgs{ + Entities: []Entity{ + {Name: "FormatTest", EntityType: "Test"}, + }, + }) + if err != nil { + t.Fatalf("CreateEntities failed: %v", err) + } + + // Verify response has both Content and StructuredContent + if len(result.Content) == 0 { + t.Errorf("expected Content field to be populated") + } + if len(out.Entities) == 0 { + t.Errorf("expected StructuredContent.Entities to be populated") + } + + // Verify Content contains simple success message + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { + expectedMessage := "Entities created successfully" + if textContent.Text != expectedMessage { + t.Errorf("expected Content field to contain '%s', got '%s'", expectedMessage, textContent.Text) + } + } else { + t.Errorf("expected Content[0] to be TextContent") + } +} diff --git a/examples/server/memory/main.go b/examples/server/memory/main.go new file mode 100644 index 00000000..99d109a6 --- /dev/null +++ b/examples/server/memory/main.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "log" + "net/http" + "os" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + memoryFilePath = flag.String("memory", "", "if set, persist the knowledge base to this file; otherwise, it will be stored in memory and lost on exit") +) + +// HiArgs defines arguments for the greeting tool. +type HiArgs struct { + Name string `json:"name"` +} + +// CreateEntitiesArgs defines the create entities tool parameters. +type CreateEntitiesArgs struct { + Entities []Entity `json:"entities" mcp:"entities to create"` +} + +// CreateEntitiesResult returns newly created entities. +type CreateEntitiesResult struct { + Entities []Entity `json:"entities"` +} + +// CreateRelationsArgs defines the create relations tool parameters. +type CreateRelationsArgs struct { + Relations []Relation `json:"relations" mcp:"relations to create"` +} + +// CreateRelationsResult returns newly created relations. +type CreateRelationsResult struct { + Relations []Relation `json:"relations"` +} + +// AddObservationsArgs defines the add observations tool parameters. +type AddObservationsArgs struct { + Observations []Observation `json:"observations" mcp:"observations to add"` +} + +// AddObservationsResult returns newly added observations. +type AddObservationsResult struct { + Observations []Observation `json:"observations"` +} + +// DeleteEntitiesArgs defines the delete entities tool parameters. +type DeleteEntitiesArgs struct { + EntityNames []string `json:"entityNames" mcp:"entities to delete"` +} + +// DeleteObservationsArgs defines the delete observations tool parameters. +type DeleteObservationsArgs struct { + Deletions []Observation `json:"deletions" mcp:"obeservations to delete"` +} + +// DeleteRelationsArgs defines the delete relations tool parameters. +type DeleteRelationsArgs struct { + Relations []Relation `json:"relations" mcp:"relations to delete"` +} + +// SearchNodesArgs defines the search nodes tool parameters. +type SearchNodesArgs struct { + Query string `json:"query" mcp:"query string"` +} + +// OpenNodesArgs defines the open nodes tool parameters. +type OpenNodesArgs struct { + Names []string `json:"names" mcp:"names of nodes to open"` +} + +func main() { + flag.Parse() + + // Initialize storage backend + var kbStore store + kbStore = &memoryStore{} + if *memoryFilePath != "" { + kbStore = &fileStore{path: *memoryFilePath} + } + kb := knowledgeBase{s: kbStore} + + // Setup MCP server with knowledge base tools + server := mcp.NewServer(&mcp.Implementation{Name: "memory"}, nil) + mcp.AddTool(server, &mcp.Tool{ + Name: "create_entities", + Description: "Create multiple new entities in the knowledge graph", + }, kb.CreateEntities) + mcp.AddTool(server, &mcp.Tool{ + Name: "create_relations", + Description: "Create multiple new relations between entities", + }, kb.CreateRelations) + mcp.AddTool(server, &mcp.Tool{ + Name: "add_observations", + Description: "Add new observations to existing entities", + }, kb.AddObservations) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_entities", + Description: "Remove entities and their relations", + }, kb.DeleteEntities) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_observations", + Description: "Remove specific observations from entities", + }, kb.DeleteObservations) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_relations", + Description: "Remove specific relations from the graph", + }, kb.DeleteRelations) + mcp.AddTool(server, &mcp.Tool{ + Name: "read_graph", + Description: "Read the entire knowledge graph", + }, kb.ReadGraph) + mcp.AddTool(server, &mcp.Tool{ + Name: "search_nodes", + Description: "Search for nodes based on query", + }, kb.SearchNodes) + mcp.AddTool(server, &mcp.Tool{ + Name: "open_nodes", + Description: "Retrieve specific nodes by name", + }, kb.OpenNodes) + + // Start server with appropriate transport + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("MCP handler listening at %s", *httpAddr) + http.ListenAndServe(*httpAddr, handler) + } else { + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} diff --git a/examples/server/middleware/main.go b/examples/server/middleware/main.go new file mode 100644 index 00000000..4e37471e --- /dev/null +++ b/examples/server/middleware/main.go @@ -0,0 +1,148 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// This example demonstrates server side logging using the mcp.Middleware system. +func main() { + // Create a logger for demonstration purposes. + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { + // Simplify timestamp format for consistent output. + if a.Key == slog.TimeKey { + return slog.String("time", "2025-01-01T00:00:00Z") + } + return a + }, + })) + + loggingMiddleware := func(next mcp.MethodHandler) mcp.MethodHandler { + return func( + ctx context.Context, + method string, + req mcp.Request, + ) (mcp.Result, error) { + logger.Info("MCP method started", + "method", method, + "session_id", req.GetSession().ID(), + "has_params", req.GetParams() != nil, + ) + // Log more for tool calls. + if ctr, ok := req.(*mcp.CallToolRequest); ok { + logger.Info("Calling tool", + "name", ctr.Params.Name, + "args", ctr.Params.Arguments) + } + + start := time.Now() + result, err := next(ctx, method, req) + duration := time.Since(start) + if err != nil { + logger.Error("MCP method failed", + "method", method, + "session_id", req.GetSession().ID(), + "duration_ms", duration.Milliseconds(), + "err", err, + ) + } else { + logger.Info("MCP method completed", + "method", method, + "session_id", req.GetSession().ID(), + "duration_ms", duration.Milliseconds(), + "has_result", result != nil, + ) + // Log more for tool results. + if ctr, ok := result.(*mcp.CallToolResult); ok { + logger.Info("tool result", + "isError", ctr.IsError, + "structuredContent", ctr.StructuredContent) + } + } + return result, err + } + } + + // Create server with middleware + server := mcp.NewServer(&mcp.Implementation{Name: "logging-example"}, nil) + server.AddReceivingMiddleware(loggingMiddleware) + + // Add a simple tool + mcp.AddTool(server, + &mcp.Tool{ + Name: "greet", + Description: "Greet someone with logging.", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": { + Type: "string", + Description: "Name to greet", + }, + }, + Required: []string{"name"}, + }, + }, + func( + ctx context.Context, + req *mcp.CallToolRequest, args map[string]any, + ) (*mcp.CallToolResult, any, error) { + name, ok := args["name"].(string) + if !ok { + return nil, nil, fmt.Errorf("name parameter is required and must be a string") + } + + message := fmt.Sprintf("Hello, %s!", name) + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: message}, + }, + }, message, nil + }, + ) + + // Create client-server connection for demonstration + client := mcp.NewClient(&mcp.Implementation{Name: "test-client"}, nil) + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + ctx := context.Background() + + // Connect server and client + serverSession, _ := server.Connect(ctx, serverTransport, nil) + defer serverSession.Close() + + clientSession, _ := client.Connect(ctx, clientTransport, nil) + defer clientSession.Close() + + // Call the tool to demonstrate logging + result, _ := clientSession.CallTool(ctx, &mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]any{ + "name": "World", + }, + }) + + fmt.Printf("Tool result: %s\n", result.Content[0].(*mcp.TextContent).Text) + + // Output: + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=initialize session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=initialize session_id="" duration_ms=0 has_result=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=notifications/initialized session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=notifications/initialized session_id="" duration_ms=0 has_result=false + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method started" method=tools/call session_id="" has_params=true + // time=2025-01-01T00:00:00Z level=INFO msg="Calling tool" name=greet args="{\"name\":\"World\"}" + // time=2025-01-01T00:00:00Z level=INFO msg="MCP method completed" method=tools/call session_id="" duration_ms=0 has_result=true + // Tool result: Hello, World! +} diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod new file mode 100644 index 00000000..91d3d269 --- /dev/null +++ b/examples/server/rate-limiting/go.mod @@ -0,0 +1,15 @@ +module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting + +go 1.23.0 + +require ( + github.com/modelcontextprotocol/go-sdk v0.3.0 + golang.org/x/time v0.12.0 +) + +require ( + github.com/google/jsonschema-go v0.2.3 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect +) + +replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/rate-limiting/go.sum b/examples/server/rate-limiting/go.sum similarity index 56% rename from examples/rate-limiting/go.sum rename to examples/server/rate-limiting/go.sum index c7027682..da49fd16 100644 --- a/examples/rate-limiting/go.sum +++ b/examples/server/rate-limiting/go.sum @@ -1,7 +1,9 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 h1:kUGBYP25FTv3ZRBhLT4iQvtx4FDl7hPkWe3isYrMxyo= -github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89/go.mod h1:DcXfbr7yl7e35oMpzHfKw2nUYRjhIGS2uou/6tdsTB0= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/examples/server/rate-limiting/main.go b/examples/server/rate-limiting/main.go new file mode 100644 index 00000000..e107183f --- /dev/null +++ b/examples/server/rate-limiting/main.go @@ -0,0 +1,96 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "errors" + "log" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "golang.org/x/time/rate" +) + +// GlobalRateLimiterMiddleware creates a middleware that applies a global rate limit. +// Every request attempting to pass through will try to acquire a token. +// If a token cannot be acquired immediately, the request will be rejected. +func GlobalRateLimiterMiddleware(limiter *rate.Limiter) mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + return next(ctx, method, req) + } + } +} + +// PerMethodRateLimiterMiddleware creates a middleware that applies rate limiting +// on a per-method basis. +// Methods not specified in limiters will not be rate limited by this middleware. +func PerMethodRateLimiterMiddleware(limiters map[string]*rate.Limiter) mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + if limiter, ok := limiters[method]; ok { + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + } + return next(ctx, method, req) + } + } +} + +// PerSessionRateLimiterMiddleware creates a middleware that applies rate limiting +// on a per-session basis for receiving requests. +func PerSessionRateLimiterMiddleware(limit rate.Limit, burst int) mcp.Middleware { + // A map to store limiters, keyed by the session ID. + var ( + sessionLimiters = make(map[string]*rate.Limiter) + mu sync.Mutex + ) + + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // It's possible that session.ID() may be empty at this point in time + // for some transports (e.g., stdio) or until the MCP initialize handshake + // has completed. + sessionID := req.GetSession().ID() + if sessionID == "" { + // In this situation, you could apply a single global identifier + // if session ID is empty or bypass the rate limiter. + // In this example, we bypass the rate limiter. + log.Printf("Warning: Session ID is empty for method %q. Skipping per-session rate limiting.", method) + return next(ctx, method, req) // Skip limiting if ID is unavailable + } + mu.Lock() + limiter, ok := sessionLimiters[sessionID] + if !ok { + limiter = rate.NewLimiter(limit, burst) + sessionLimiters[sessionID] = limiter + } + mu.Unlock() + if !limiter.Allow() { + return nil, errors.New("JSON RPC overloaded") + } + return next(ctx, method, req) + } + } +} + +func main() { + server := mcp.NewServer(&mcp.Implementation{Name: "greeter1", Version: "v0.0.1"}, nil) + server.AddReceivingMiddleware(GlobalRateLimiterMiddleware(rate.NewLimiter(rate.Every(time.Second/5), 10))) + server.AddReceivingMiddleware(PerMethodRateLimiterMiddleware(map[string]*rate.Limiter{ + "callTool": rate.NewLimiter(rate.Every(time.Second), 5), // once a second with a burst up to 5 + "listTools": rate.NewLimiter(rate.Every(time.Minute), 20), // once a minute with a burst up to 20 + })) + server.AddReceivingMiddleware(PerSessionRateLimiterMiddleware(rate.Every(time.Second/5), 10)) + // Run Server logic. + log.Println("MCP Server instance created with Middleware (but not running).") + log.Println("This example demonstrates configuration, not live interaction.") +} diff --git a/examples/server/sequentialthinking/README.md b/examples/server/sequentialthinking/README.md new file mode 100644 index 00000000..40987b15 --- /dev/null +++ b/examples/server/sequentialthinking/README.md @@ -0,0 +1,203 @@ +# Sequential Thinking MCP Server + +This example shows a Model Context Protocol (MCP) server that enables dynamic and reflective problem-solving through structured thinking processes. It helps break down complex problems into manageable, sequential thought steps with support for revision and branching. + +## Features + +The server provides three main tools for managing thinking sessions: + +### 1. Start Thinking (`start_thinking`) + +Begins a new sequential thinking session for a complex problem. + +**Parameters:** + +- `problem` (string): The problem or question to think about +- `sessionId` (string, optional): Custom session identifier +- `estimatedSteps` (int, optional): Initial estimate of thinking steps needed + +### 2. Continue Thinking (`continue_thinking`) + +Adds the next thought step, revises previous steps, or creates alternative branches. + +**Parameters:** + +- `sessionId` (string): The thinking session to continue +- `thought` (string): The current thought or analysis +- `nextNeeded` (bool, optional): Whether another thinking step is needed +- `reviseStep` (int, optional): Step number to revise (1-based) +- `createBranch` (bool, optional): Create an alternative reasoning path +- `estimatedTotal` (int, optional): Update total estimated steps + +### 3. Review Thinking (`review_thinking`) + +Provides a complete review of the thinking process for a session. + +**Parameters:** + +- `sessionId` (string): The session to review + +## Resources + +### Thinking History (`thinking://sessions` or `thinking://{sessionId}`) + +Access thinking session data and history in JSON format. + +- `thinking://sessions` - List all thinking sessions +- `thinking://{sessionId}` - Get specific session details + +## Core Concepts + +### Sequential Processing + +Problems are broken down into numbered thought steps that build upon each other, maintaining context and allowing for systematic analysis. + +### Dynamic Revision + +Any previous thought step can be revised and updated, with the system tracking which thoughts have been modified. + +### Alternative Branching + +Create alternative reasoning paths to explore different approaches to the same problem, allowing for comparative analysis. + +### Adaptive Planning + +The estimated number of thinking steps can be adjusted dynamically as understanding of the problem evolves. + +## Running the Server + +### Standard I/O Mode + +```bash +go run . +``` + +### HTTP Mode + +```bash +go run . -http :8080 +``` + +## Example Usage + +### Starting a Thinking Session + +```json +{ + "method": "tools/call", + "params": { + "name": "start_thinking", + "arguments": { + "problem": "How should I design a scalable microservices architecture?", + "sessionId": "architecture_design", + "estimatedSteps": 8 + } + } +} +``` + +### Adding Sequential Thoughts + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "First, I need to identify the core business domains and their boundaries to determine service decomposition." + } + } +} +``` + +### Revising a Previous Step + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "Actually, before identifying domains, I should analyze the current system's pain points and requirements.", + "reviseStep": 1 + } + } +} +``` + +### Creating an Alternative Branch + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "Alternative approach: Start with a monolith-first strategy and extract services gradually.", + "createBranch": true + } + } +} +``` + +### Completing the Thinking Process + +```json +{ + "method": "tools/call", + "params": { + "name": "continue_thinking", + "arguments": { + "sessionId": "architecture_design", + "thought": "Based on this analysis, I recommend starting with 3 core services: User Management, Order Processing, and Inventory Management.", + "nextNeeded": false + } + } +} +``` + +### Reviewing the Complete Process + +```json +{ + "method": "tools/call", + "params": { + "name": "review_thinking", + "arguments": { + "sessionId": "architecture_design" + } + } +} +``` + +## Session State Management + +Each thinking session maintains: + +- **Session metadata**: ID, problem statement, creation time, current status +- **Thought sequence**: Ordered list of thoughts with timestamps and revision history +- **Progress tracking**: Current step and estimated total steps +- **Branch relationships**: Links to alternative reasoning paths +- **Status management**: Active, completed, or paused sessions + +## Use Cases + +**Ideal for:** + +- Complex problem analysis requiring step-by-step breakdown +- Design decisions needing systematic evaluation +- Scenarios where initial scope is unclear and may evolve +- Problems requiring alternative approach exploration +- Situations needing detailed reasoning documentation + +**Examples:** + +- Software architecture design +- Research methodology planning +- Strategic business decisions +- Technical troubleshooting +- Creative problem solving +- Academic research planning diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go new file mode 100644 index 00000000..e0ae5219 --- /dev/null +++ b/examples/server/sequentialthinking/main.go @@ -0,0 +1,538 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "crypto/rand" + "encoding/json" + "flag" + "fmt" + "log" + "maps" + "net/http" + "net/url" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + +// A Thought is a single step in the thinking process. +type Thought struct { + // Index of the thought within the session (1-based). + Index int `json:"index"` + // Content of the thought. + Content string `json:"content"` + // Time the thought was created. + Created time.Time `json:"created"` + // Whether the thought has been revised. + Revised bool `json:"revised"` + // Index of parent thought, or nil if this is a root for branching. + ParentIndex *int `json:"parentIndex,omitempty"` +} + +// A ThinkingSession is an active thinking session. +type ThinkingSession struct { + // Globally unique ID of the session. + ID string `json:"id"` + // Problem to solve. + Problem string `json:"problem"` + // Thoughts in the session. + Thoughts []*Thought `json:"thoughts"` + // Current thought index. + CurrentThought int `json:"currentThought"` + // Estimated total number of thoughts. + EstimatedTotal int `json:"estimatedTotal"` + // Status of the session. + Status string `json:"status"` // "active", "completed", "paused" + // Time the session was created. + Created time.Time `json:"created"` + // Time the session was last active. + LastActivity time.Time `json:"lastActivity"` + // Branches in the session. Alternative thought paths. + Branches []string `json:"branches,omitempty"` + // Version for optimistic concurrency control. + Version int `json:"version"` +} + +// clone returns a deep copy of the ThinkingSession. +func (s *ThinkingSession) clone() *ThinkingSession { + sessionCopy := *s + sessionCopy.Thoughts = deepCopyThoughts(s.Thoughts) + sessionCopy.Branches = slices.Clone(s.Branches) + return &sessionCopy +} + +// A SessionStore is a global session store (in a real implementation, this might be a database). +// +// Locking Strategy: +// The SessionStore uses a RWMutex to protect the sessions map from concurrent access. +// All ThinkingSession modifications happen on deep copies, never on shared instances. +// This means: +// - Read locks protect map access. +// - Write locks protect map modifications (adding/removing/replacing sessions) +// - Session field modifications always happen on local copies via CompareAndSwap +// - No shared ThinkingSession state is ever modified directly +type SessionStore struct { + mu sync.RWMutex + sessions map[string]*ThinkingSession // key is session ID +} + +// NewSessionStore creates a new session store for managing thinking sessions. +func NewSessionStore() *SessionStore { + return &SessionStore{ + sessions: make(map[string]*ThinkingSession), + } +} + +// Session retrieves a thinking session by ID, returning the session and whether it exists. +func (s *SessionStore) Session(id string) (*ThinkingSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + session, exists := s.sessions[id] + return session, exists +} + +// SetSession stores or updates a thinking session in the store. +func (s *SessionStore) SetSession(session *ThinkingSession) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessions[session.ID] = session +} + +// CompareAndSwap atomically updates a session if the version matches. +// Returns true if the update succeeded, false if there was a version mismatch. +// +// This method implements optimistic concurrency control: +// 1. Read lock to safely access the map and copy the session +// 2. Deep copy the session (all modifications happen on this copy) +// 3. Release read lock and apply updates to the copy +// 4. Write lock to check version and atomically update if unchanged +// +// The read lock in step 1 is necessary to prevent map access races, +// not to protect ThinkingSession fields (which are never modified in-place). +func (s *SessionStore) CompareAndSwap(sessionID string, updateFunc func(*ThinkingSession) (*ThinkingSession, error)) error { + for { + // Get current session + s.mu.RLock() + current, exists := s.sessions[sessionID] + if !exists { + s.mu.RUnlock() + return fmt.Errorf("session %s not found", sessionID) + } + // Create a deep copy + sessionCopy := current.clone() + oldVersion := current.Version + s.mu.RUnlock() + + // Apply the update + updated, err := updateFunc(sessionCopy) + if err != nil { + return err + } + + // Try to save + s.mu.Lock() + current, exists = s.sessions[sessionID] + if !exists { + s.mu.Unlock() + return fmt.Errorf("session %s not found", sessionID) + } + if current.Version != oldVersion { + // Version mismatch, retry + s.mu.Unlock() + continue + } + updated.Version = oldVersion + 1 + s.sessions[sessionID] = updated + s.mu.Unlock() + return nil + } +} + +// Sessions returns all thinking sessions in the store. +func (s *SessionStore) Sessions() []*ThinkingSession { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Collect(maps.Values(s.sessions)) +} + +// SessionsSnapshot returns a deep copy of all sessions for safe concurrent access. +func (s *SessionStore) SessionsSnapshot() []*ThinkingSession { + s.mu.RLock() + defer s.mu.RUnlock() + + var sessions []*ThinkingSession + for _, session := range s.sessions { + sessions = append(sessions, session.clone()) + } + return sessions +} + +// SessionSnapshot returns a deep copy of a session for safe concurrent access. +// The second return value reports whether a session with the given id exists. +func (s *SessionStore) SessionSnapshot(id string) (*ThinkingSession, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + session, exists := s.sessions[id] + if !exists { + return nil, false + } + + return session.clone(), true +} + +var store = NewSessionStore() + +// StartThinkingArgs are the arguments for starting a new thinking session. +type StartThinkingArgs struct { + Problem string `json:"problem"` + SessionID string `json:"sessionId,omitempty"` + EstimatedSteps int `json:"estimatedSteps,omitempty"` +} + +// ContinueThinkingArgs are the arguments for continuing a thinking session. +type ContinueThinkingArgs struct { + SessionID string `json:"sessionId"` + Thought string `json:"thought"` + NextNeeded *bool `json:"nextNeeded,omitempty"` + ReviseStep *int `json:"reviseStep,omitempty"` + CreateBranch bool `json:"createBranch,omitempty"` + EstimatedTotal int `json:"estimatedTotal,omitempty"` +} + +// ReviewThinkingArgs are the arguments for reviewing a thinking session. +type ReviewThinkingArgs struct { + SessionID string `json:"sessionId"` +} + +// ThinkingHistoryArgs are the arguments for retrieving thinking history. +type ThinkingHistoryArgs struct { + SessionID string `json:"sessionId"` +} + +// deepCopyThoughts creates a deep copy of a slice of thoughts. +func deepCopyThoughts(thoughts []*Thought) []*Thought { + thoughtsCopy := make([]*Thought, len(thoughts)) + for i, t := range thoughts { + t2 := *t + thoughtsCopy[i] = &t2 + } + return thoughtsCopy +} + +// StartThinking begins a new sequential thinking session for a complex problem. +func StartThinking(ctx context.Context, _ *mcp.CallToolRequest, args StartThinkingArgs) (*mcp.CallToolResult, any, error) { + sessionID := args.SessionID + if sessionID == "" { + sessionID = randText() + } + + estimatedSteps := args.EstimatedSteps + if estimatedSteps == 0 { + estimatedSteps = 5 // Default estimate + } + + session := &ThinkingSession{ + ID: sessionID, + Problem: args.Problem, + EstimatedTotal: estimatedSteps, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + + store.SetSession(session) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Started thinking session '%s' for problem: %s\nEstimated steps: %d\nReady for your first thought.", + sessionID, args.Problem, estimatedSteps), + }, + }, + }, nil, nil +} + +// ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. +func ContinueThinking(ctx context.Context, req *mcp.CallToolRequest, args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { + // Handle revision of existing thought + if args.ReviseStep != nil { + err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { + stepIndex := *args.ReviseStep - 1 + if stepIndex < 0 || stepIndex >= len(session.Thoughts) { + return nil, fmt.Errorf("invalid step number: %d", *args.ReviseStep) + } + + session.Thoughts[stepIndex].Content = args.Thought + session.Thoughts[stepIndex].Revised = true + session.LastActivity = time.Now() + return session, nil + }) + if err != nil { + return nil, nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Revised step %d in session '%s':\n%s", + *args.ReviseStep, args.SessionID, args.Thought), + }, + }, + }, nil, nil + } + + // Handle branching + if args.CreateBranch { + var branchID string + var branchSession *ThinkingSession + + err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { + branchID = fmt.Sprintf("%s_branch_%d", args.SessionID, len(session.Branches)+1) + session.Branches = append(session.Branches, branchID) + session.LastActivity = time.Now() + + // Create a new session for the branch (deep copy thoughts) + thoughtsCopy := deepCopyThoughts(session.Thoughts) + branchSession = &ThinkingSession{ + ID: branchID, + Problem: session.Problem + " (Alternative branch)", + Thoughts: thoughtsCopy, + CurrentThought: len(session.Thoughts), + EstimatedTotal: session.EstimatedTotal, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + + return session, nil + }) + if err != nil { + return nil, nil, err + } + + // Save the branch session + store.SetSession(branchSession) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Created branch '%s' from session '%s'. You can now continue thinking in either session.", + branchID, args.SessionID), + }, + }, + }, nil, nil + } + + // Add new thought + var thoughtID int + var progress string + var statusMsg string + + err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { + thoughtID = len(session.Thoughts) + 1 + thought := &Thought{ + Index: thoughtID, + Content: args.Thought, + Created: time.Now(), + Revised: false, + } + + session.Thoughts = append(session.Thoughts, thought) + session.CurrentThought = thoughtID + session.LastActivity = time.Now() + + // Update estimated total if provided + if args.EstimatedTotal > 0 { + session.EstimatedTotal = args.EstimatedTotal + } + + // Check if thinking is complete + if args.NextNeeded != nil && !*args.NextNeeded { + session.Status = "completed" + } + + // Prepare response strings + progress = fmt.Sprintf("Step %d", thoughtID) + if session.EstimatedTotal > 0 { + progress += fmt.Sprintf(" of ~%d", session.EstimatedTotal) + } + + if session.Status == "completed" { + statusMsg = "\n✓ Thinking process completed!" + } else { + statusMsg = "\nReady for next thought..." + } + + return session, nil + }) + if err != nil { + return nil, nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: fmt.Sprintf("Session '%s' - %s:\n%s%s", + args.SessionID, progress, args.Thought, statusMsg), + }, + }, + }, nil, nil +} + +// ReviewThinking provides a complete review of the thinking process for a session. +func ReviewThinking(ctx context.Context, req *mcp.CallToolRequest, args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { + // Get a snapshot of the session to avoid race conditions + sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) + if !exists { + return nil, nil, fmt.Errorf("session %s not found", args.SessionID) + } + + var review strings.Builder + fmt.Fprintf(&review, "=== Thinking Review: %s ===\n", sessionSnapshot.ID) + fmt.Fprintf(&review, "Problem: %s\n", sessionSnapshot.Problem) + fmt.Fprintf(&review, "Status: %s\n", sessionSnapshot.Status) + fmt.Fprintf(&review, "Steps: %d of ~%d\n", len(sessionSnapshot.Thoughts), sessionSnapshot.EstimatedTotal) + + if len(sessionSnapshot.Branches) > 0 { + fmt.Fprintf(&review, "Branches: %s\n", strings.Join(sessionSnapshot.Branches, ", ")) + } + + fmt.Fprintf(&review, "\n--- Thought Sequence ---\n") + + for i, thought := range sessionSnapshot.Thoughts { + status := "" + if thought.Revised { + status = " (revised)" + } + fmt.Fprintf(&review, "%d. %s%s\n", i+1, thought.Content, status) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{ + Text: review.String(), + }, + }, + }, nil, nil +} + +// ThinkingHistory handles resource requests for thinking session data and history. +func ThinkingHistory(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + // Extract session ID from URI (e.g., "thinking://session_123") + u, err := url.Parse(req.Params.URI) + if err != nil { + return nil, fmt.Errorf("invalid thinking resource URI: %s", req.Params.URI) + } + if u.Scheme != "thinking" { + return nil, fmt.Errorf("invalid thinking resource URI scheme: %s", u.Scheme) + } + + sessionID := u.Host + if sessionID == "sessions" { + // List all sessions - use snapshot for thread safety + sessions := store.SessionsSnapshot() + data, err := json.MarshalIndent(sessions, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal sessions: %w", err) + } + + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: req.Params.URI, + MIMEType: "application/json", + Text: string(data), + }, + }, + }, nil + } + + // Get specific session - use snapshot for thread safety + session, exists := store.SessionSnapshot(sessionID) + if !exists { + return nil, fmt.Errorf("session %s not found", sessionID) + } + + data, err := json.MarshalIndent(session, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal session: %w", err) + } + + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: req.Params.URI, + MIMEType: "application/json", + Text: string(data), + }, + }, + }, nil +} + +// Copied from crypto/rand. +// TODO: once 1.24 is assured, just use crypto/rand. +const base32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" + +func randText() string { + // ⌈log₃₂ 2¹²⁸⌉ = 26 chars + src := make([]byte, 26) + rand.Read(src) + for i := range src { + src[i] = base32alphabet[src[i]%32] + } + return string(src) +} + +func main() { + flag.Parse() + + server := mcp.NewServer(&mcp.Implementation{Name: "sequential-thinking"}, nil) + + mcp.AddTool(server, &mcp.Tool{ + Name: "start_thinking", + Description: "Begin a new sequential thinking session for a complex problem", + }, StartThinking) + + mcp.AddTool(server, &mcp.Tool{ + Name: "continue_thinking", + Description: "Add the next thought step, revise a previous step, or create a branch", + }, ContinueThinking) + + mcp.AddTool(server, &mcp.Tool{ + Name: "review_thinking", + Description: "Review the complete thinking process for a session", + }, ReviewThinking) + + server.AddResource(&mcp.Resource{ + Name: "thinking_sessions", + Description: "Access thinking session data and history", + URI: "thinking://sessions", + MIMEType: "application/json", + }, ThinkingHistory) + + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("Sequential Thinking MCP server listening at %s", *httpAddr) + if err := http.ListenAndServe(*httpAddr, handler); err != nil { + log.Fatal(err) + } + } else { + t := &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr} + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} diff --git a/examples/server/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go new file mode 100644 index 00000000..2655114c --- /dev/null +++ b/examples/server/sequentialthinking/main_test.go @@ -0,0 +1,491 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestStartThinking(t *testing.T) { + // Reset store for clean test + store = NewSessionStore() + + ctx := context.Background() + + args := StartThinkingArgs{ + Problem: "How to implement a binary search algorithm", + SessionID: "test_session", + EstimatedSteps: 5, + } + + result, _, err := StartThinking(ctx, nil, args) + if err != nil { + t.Fatalf("StartThinking() error = %v", err) + } + + if len(result.Content) == 0 { + t.Fatal("No content in result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "test_session") { + t.Error("Result should contain session ID") + } + + if !strings.Contains(textContent.Text, "How to implement a binary search algorithm") { + t.Error("Result should contain the problem statement") + } + + // Verify session was stored + session, exists := store.Session("test_session") + if !exists { + t.Fatal("Session was not stored") + } + + if session.Problem != args.Problem { + t.Errorf("Expected problem %s, got %s", args.Problem, session.Problem) + } + + if session.EstimatedTotal != 5 { + t.Errorf("Expected estimated total 5, got %d", session.EstimatedTotal) + } + + if session.Status != "active" { + t.Errorf("Expected status 'active', got %s", session.Status) + } +} + +func TestContinueThinking(t *testing.T) { + // Reset store and create initial session + store = NewSessionStore() + + // First start a thinking session + ctx := context.Background() + startArgs := StartThinkingArgs{ + Problem: "Test problem", + SessionID: "test_continue", + EstimatedSteps: 3, + } + + _, _, err := StartThinking(ctx, nil, startArgs) + if err != nil { + t.Fatalf("StartThinking() error = %v", err) + } + + // Now continue thinking + continueArgs := ContinueThinkingArgs{ + SessionID: "test_continue", + Thought: "First thought: I need to understand the problem", + } + + result, _, err := ContinueThinking(ctx, nil, continueArgs) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Verify result + if len(result.Content) == 0 { + t.Fatal("No content in result") + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Step 1") { + t.Error("Result should contain step number") + } + + // Verify session was updated + session, exists := store.Session("test_continue") + if !exists { + t.Fatal("Session not found") + } + + if len(session.Thoughts) != 1 { + t.Errorf("Expected 1 thought, got %d", len(session.Thoughts)) + } + + if session.Thoughts[0].Content != continueArgs.Thought { + t.Errorf("Expected thought content %s, got %s", continueArgs.Thought, session.Thoughts[0].Content) + } + + if session.CurrentThought != 1 { + t.Errorf("Expected current thought 1, got %d", session.CurrentThought) + } +} + +func TestContinueThinkingWithCompletion(t *testing.T) { + // Reset store and create initial session + store = NewSessionStore() + + ctx := context.Background() + startArgs := StartThinkingArgs{ + Problem: "Simple test", + SessionID: "test_completion", + } + + _, _, err := StartThinking(ctx, nil, startArgs) + if err != nil { + t.Fatalf("StartThinking() error = %v", err) + } + + // Continue with completion flag + nextNeeded := false + continueArgs := ContinueThinkingArgs{ + SessionID: "test_completion", + Thought: "Final thought", + NextNeeded: &nextNeeded, + } + + result, _, err := ContinueThinking(ctx, nil, continueArgs) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Check completion message + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "completed") { + t.Error("Result should indicate completion") + } + + // Verify session status + session, exists := store.Session("test_completion") + if !exists { + t.Fatal("Session not found") + } + + if session.Status != "completed" { + t.Errorf("Expected status 'completed', got %s", session.Status) + } +} + +func TestContinueThinkingRevision(t *testing.T) { + // Setup session with existing thoughts + store = NewSessionStore() + session := &ThinkingSession{ + ID: "test_revision", + Problem: "Test problem", + Thoughts: []*Thought{ + {Index: 1, Content: "Original thought", Created: time.Now()}, + {Index: 2, Content: "Second thought", Created: time.Now()}, + }, + CurrentThought: 2, + EstimatedTotal: 3, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + store.SetSession(session) + + ctx := context.Background() + reviseStep := 1 + continueArgs := ContinueThinkingArgs{ + SessionID: "test_revision", + Thought: "Revised first thought", + ReviseStep: &reviseStep, + } + + result, _, err := ContinueThinking(ctx, nil, continueArgs) + if err != nil { + t.Fatalf("ContinueThinking() error = %v", err) + } + + // Verify revision message + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + if !strings.Contains(textContent.Text, "Revised step 1") { + t.Error("Result should indicate revision") + } + + // Verify thought was revised + updatedSession, _ := store.Session("test_revision") + if updatedSession.Thoughts[0].Content != "Revised first thought" { + t.Error("First thought should be revised") + } + + if !updatedSession.Thoughts[0].Revised { + t.Error("First thought should be marked as revised") + } +} + +// func TestContinueThinkingBranching(t *testing.T) { +// // Setup session with existing thoughts +// store = NewSessionStore() +// session := &ThinkingSession{ +// ID: "test_branch", +// Problem: "Test problem", +// Thoughts: []*Thought{ +// {Index: 1, Content: "First thought", Created: time.Now()}, +// }, +// CurrentThought: 1, +// EstimatedTotal: 3, +// Status: "active", +// Created: time.Now(), +// LastActivity: time.Now(), +// Branches: []string{}, +// } +// store.SetSession(session) + +// ctx := context.Background() +// continueArgs := ContinueThinkingArgs{ +// SessionID: "test_branch", +// Thought: "Alternative approach", +// CreateBranch: true, +// } + +// continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ +// Name: "continue_thinking", +// Arguments: continueArgs, +// } + +// // Verify branch creation message +// textContent, ok := result.Content[0].(*mcp.TextContent) +// if !ok { +// t.Fatal("Expected TextContent") +// } + +// if !strings.Contains(textContent.Text, "Created branch") { +// t.Error("Result should indicate branch creation") +// } + +// // Verify branch was created +// updatedSession, _ := store.Session("test_branch") +// if len(updatedSession.Branches) != 1 { +// t.Errorf("Expected 1 branch, got %d", len(updatedSession.Branches)) +// } + +// branchID := updatedSession.Branches[0] +// if !strings.Contains(branchID, "test_branch_branch_") { +// t.Error("Branch ID should contain parent session ID") +// } + +// // Verify branch session exists +// branchSession, exists := store.Session(branchID) +// if !exists { +// t.Fatal("Branch session should exist") +// } + +// if len(branchSession.Thoughts) != 1 { +// t.Error("Branch should inherit parent thoughts") +// } +// } + +func TestReviewThinking(t *testing.T) { + // Setup session with thoughts + store = NewSessionStore() + session := &ThinkingSession{ + ID: "test_review", + Problem: "Complex problem", + Thoughts: []*Thought{ + {Index: 1, Content: "First thought", Created: time.Now(), Revised: false}, + {Index: 2, Content: "Second thought", Created: time.Now(), Revised: true}, + {Index: 3, Content: "Final thought", Created: time.Now(), Revised: false}, + }, + CurrentThought: 3, + EstimatedTotal: 3, + Status: "completed", + Created: time.Now(), + LastActivity: time.Now(), + Branches: []string{"test_review_branch_1"}, + } + store.SetSession(session) + + ctx := context.Background() + reviewArgs := ReviewThinkingArgs{ + SessionID: "test_review", + } + + result, _, err := ReviewThinking(ctx, nil, reviewArgs) + if err != nil { + t.Fatalf("ReviewThinking() error = %v", err) + } + + // Verify review content + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatal("Expected TextContent") + } + + reviewText := textContent.Text + + if !strings.Contains(reviewText, "test_review") { + t.Error("Review should contain session ID") + } + + if !strings.Contains(reviewText, "Complex problem") { + t.Error("Review should contain problem") + } + + if !strings.Contains(reviewText, "completed") { + t.Error("Review should contain status") + } + + if !strings.Contains(reviewText, "Steps: 3 of ~3") { + t.Error("Review should contain step count") + } + + if !strings.Contains(reviewText, "First thought") { + t.Error("Review should contain first thought") + } + + if !strings.Contains(reviewText, "(revised)") { + t.Error("Review should indicate revised thoughts") + } + + if !strings.Contains(reviewText, "test_review_branch_1") { + t.Error("Review should list branches") + } +} + +func TestThinkingHistory(t *testing.T) { + // Setup test sessions + store = NewSessionStore() + session1 := &ThinkingSession{ + ID: "session1", + Problem: "Problem 1", + Thoughts: []*Thought{{Index: 1, Content: "Thought 1", Created: time.Now()}}, + CurrentThought: 1, + EstimatedTotal: 2, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + session2 := &ThinkingSession{ + ID: "session2", + Problem: "Problem 2", + Thoughts: []*Thought{{Index: 1, Content: "Thought 1", Created: time.Now()}}, + CurrentThought: 1, + EstimatedTotal: 3, + Status: "completed", + Created: time.Now(), + LastActivity: time.Now(), + } + store.SetSession(session1) + store.SetSession(session2) + + ctx := context.Background() + + // Test listing all sessions + result, err := ThinkingHistory(ctx, &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{ + URI: "thinking://sessions", + }, + }) + if err != nil { + t.Fatalf("ThinkingHistory() error = %v", err) + } + + if len(result.Contents) != 1 { + t.Fatal("Expected 1 content item") + } + + content := result.Contents[0] + if content.MIMEType != "application/json" { + t.Error("Expected JSON MIME type") + } + + // Parse and verify sessions list + var sessions []*ThinkingSession + err = json.Unmarshal([]byte(content.Text), &sessions) + if err != nil { + t.Fatalf("Failed to parse sessions JSON: %v", err) + } + + if len(sessions) != 2 { + t.Errorf("Expected 2 sessions, got %d", len(sessions)) + } + + // Test getting specific session + result, err = ThinkingHistory(ctx, &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{URI: "thinking://session1"}, + }) + if err != nil { + t.Fatalf("ThinkingHistory() error = %v", err) + } + + var retrievedSession ThinkingSession + err = json.Unmarshal([]byte(result.Contents[0].Text), &retrievedSession) + if err != nil { + t.Fatalf("Failed to parse session JSON: %v", err) + } + + if retrievedSession.ID != "session1" { + t.Errorf("Expected session ID 'session1', got %s", retrievedSession.ID) + } + + if retrievedSession.Problem != "Problem 1" { + t.Errorf("Expected problem 'Problem 1', got %s", retrievedSession.Problem) + } +} + +func TestInvalidOperations(t *testing.T) { + store = NewSessionStore() + ctx := context.Background() + + // Test continue thinking with non-existent session + continueArgs := ContinueThinkingArgs{ + SessionID: "nonexistent", + Thought: "Some thought", + } + + _, _, err := ContinueThinking(ctx, nil, continueArgs) + if err == nil { + t.Error("Expected error for non-existent session") + } + + // Test review with non-existent session + reviewArgs := ReviewThinkingArgs{ + SessionID: "nonexistent", + } + + _, _, err = ReviewThinking(ctx, nil, reviewArgs) + if err == nil { + t.Error("Expected error for non-existent session in review") + } + + // Test invalid revision step + session := &ThinkingSession{ + ID: "test_invalid", + Problem: "Test", + Thoughts: []*Thought{{Index: 1, Content: "Thought", Created: time.Now()}}, + CurrentThought: 1, + EstimatedTotal: 2, + Status: "active", + Created: time.Now(), + LastActivity: time.Now(), + } + store.SetSession(session) + + reviseStep := 5 // Invalid step number + invalidReviseArgs := ContinueThinkingArgs{ + SessionID: "test_invalid", + Thought: "Revised", + ReviseStep: &reviseStep, + } + + _, _, err = ContinueThinking(ctx, nil, invalidReviseArgs) + if err == nil { + t.Error("Expected error for invalid revision step") + } +} diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go new file mode 100644 index 00000000..0507dd60 --- /dev/null +++ b/examples/server/sse/main.go @@ -0,0 +1,70 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net/http" + "os" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + host = flag.String("host", "localhost", "host to listen on") + port = flag.String("port", "8080", "port to listen on") +) + +type SayHiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "This program runs MCP servers over SSE HTTP.\n\n") + fmt.Fprintf(os.Stderr, "Options:\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nEndpoints:\n") + fmt.Fprintf(os.Stderr, " /greeter1 - Greeter 1 service\n") + fmt.Fprintf(os.Stderr, " /greeter2 - Greeter 2 service\n") + os.Exit(1) + } + flag.Parse() + + addr := fmt.Sprintf("%s:%s", *host, *port) + + server1 := mcp.NewServer(&mcp.Implementation{Name: "greeter1"}, nil) + mcp.AddTool(server1, &mcp.Tool{Name: "greet1", Description: "say hi"}, SayHi) + + server2 := mcp.NewServer(&mcp.Implementation{Name: "greeter2"}, nil) + mcp.AddTool(server2, &mcp.Tool{Name: "greet2", Description: "say hello"}, SayHi) + + log.Printf("MCP servers serving at %s", addr) + handler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { + url := request.URL.Path + log.Printf("Handling request for URL %s\n", url) + switch url { + case "/greeter1": + return server1 + case "/greeter2": + return server2 + default: + return nil + } + }, nil) + log.Fatal(http.ListenAndServe(addr, handler)) +} diff --git a/examples/server/toolschemas/main.go b/examples/server/toolschemas/main.go new file mode 100644 index 00000000..36ee0495 --- /dev/null +++ b/examples/server/toolschemas/main.go @@ -0,0 +1,146 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The toolschemas example demonstrates how to create tools using both the +// low-level [ToolHandler] and high level [ToolHandlerFor], as well as how to +// customize schemas in both cases. +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Input is the input into all the tools handlers below. +type Input struct { + Name string `json:"name" jsonschema:"the person to greet"` +} + +// Output is the structured output of the tool. +// +// Not every tool needs to have structured output. +type Output struct { + Greeting string `json:"greeting" jsonschema:"the greeting to send to the user"` +} + +// simpleGreeting is an [mcp.ToolHandlerFor] that only cares about input and output. +func simpleGreeting(_ context.Context, _ *mcp.CallToolRequest, input Input) (*mcp.CallToolResult, Output, error) { + return nil, Output{"Hi " + input.Name}, nil +} + +// manualGreeter handles the parsing and validation of input and output manually. +// +// Therfore, it needs to close over its resolved schemas, to use them in +// validation. +type manualGreeter struct { + inputSchema *jsonschema.Resolved + outputSchema *jsonschema.Resolved +} + +func (t *manualGreeter) greet(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // errf produces a 'tool error', embedding the error in a CallToolResult. + errf := func(format string, args ...any) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: fmt.Sprintf(format, args...)}}, + IsError: true, + } + } + // Handle the parsing and validation of input and output. + // + // Note that errors here are treated as tool errors, not protocol errors. + var input Input + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return errf("failed to unmarshal arguments: %v", err), nil + } + if err := t.inputSchema.Validate(input); err != nil { + return errf("invalid input: %v", err), nil + } + output := Output{Greeting: "Hi " + input.Name} + if err := t.outputSchema.Validate(output); err != nil { + return errf("tool produced invalid output: %v", err), nil + } + outputJSON, err := json.Marshal(output) + if err != nil { + return errf("output failed to marshal: %v", err), nil + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(outputJSON)}}, + StructuredContent: output, + }, nil +} + +func main() { + server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) + + // Add the 'greeting' tool in a few different ways. + + // First, we can just use [mcp.AddTool], and get the out-of-the-box handling + // it provides: + mcp.AddTool(server, &mcp.Tool{Name: "simple greeting"}, simpleGreeting) + + // Next, we can create our schemas entirely manually, and add them using + // [mcp.Server.AddTool]. Since we're working manually, we can add some + // constraints on the length of the name. + // + // We don't need to do all this work: below, we use jsonschema.For to start + // from the default schema. + var ( + manual manualGreeter + err error + ) + inputSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MaxLength: jsonschema.Ptr(10)}, + }, + } + manual.inputSchema, err = inputSchema.Resolve(nil) + if err != nil { + log.Fatal(err) + } + outputSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "greeting": {Type: "string"}, + }, + } + manual.outputSchema, err = outputSchema.Resolve(nil) + if err != nil { + log.Fatal(err) + } + server.AddTool(&mcp.Tool{ + Name: "manual greeting", + InputSchema: inputSchema, + OutputSchema: outputSchema, + }, manual.greet) + + // Finally, note that we can also use custom schemas with a ToolHandlerFor. + // We can do this in two ways: by using one of the schema values constructed + // above, or by using jsonschema.For and adjusting the resulting schema. + mcp.AddTool(server, &mcp.Tool{ + Name: "customized greeting 1", + InputSchema: inputSchema, + // OutputSchema will still be derived from Output. + }, simpleGreeting) + + customSchema, err := jsonschema.For[Input](nil) + if err != nil { + log.Fatal(err) + } + customSchema.Properties["name"].MaxLength = jsonschema.Ptr(10) + mcp.AddTool(server, &mcp.Tool{ + Name: "customized greeting 2", + InputSchema: customSchema, + }, simpleGreeting) + + // Now run the server. + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { + log.Printf("Server failed: %v", err) + } +} diff --git a/examples/sse/main.go b/examples/sse/main.go deleted file mode 100644 index 97ea1bd0..00000000 --- a/examples/sse/main.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package main - -import ( - "context" - "flag" - "log" - "net/http" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -var httpAddr = flag.String("http", "", "use SSE HTTP at this address") - -type SayHiParams struct { - Name string `json:"name"` -} - -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[SayHiParams]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, - }, - }, nil -} - -func main() { - flag.Parse() - - if httpAddr == nil || *httpAddr == "" { - log.Fatal("http address not set") - } - - server1 := mcp.NewServer("greeter1", "v0.0.1", nil) - server1.AddTools(mcp.NewServerTool("greet1", "say hi", SayHi)) - - server2 := mcp.NewServer("greeter2", "v0.0.1", nil) - server2.AddTools(mcp.NewServerTool("greet2", "say hello", SayHi)) - - log.Printf("MCP servers serving at %s\n", *httpAddr) - handler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { - url := request.URL.Path - log.Printf("Handling request for URL %s\n", url) - switch url { - case "/greeter1": - return server1 - case "/greeter2": - return server2 - default: - return nil - } - }) - http.ListenAndServe(*httpAddr, handler) -} diff --git a/go.mod b/go.mod index 24e187ab..f5c578cf 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module github.com/modelcontextprotocol/go-sdk go 1.23.0 require ( + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/go-cmp v0.7.0 + github.com/google/jsonschema-go v0.2.3 + github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 6c6c2a5d..6a392638 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,10 @@ +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/internal/docs/README.src.md b/internal/docs/README.src.md new file mode 100644 index 00000000..b252f943 --- /dev/null +++ b/internal/docs/README.src.md @@ -0,0 +1,39 @@ +# Features + +These docs mirror the [official MCP spec](https://modelcontextprotocol.io/specification/2025-06-18). +Use the index below to learn how the SDK implements a particular aspect of the +protocol. + +## Base Protocol + +1. [Lifecycle (Clients, Servers, and Sessions)](protocol.md#lifecycle). +1. [Transports](protocol.md#transports) + 1. [Stdio transport](protocol.md#stdio-transport) + 1. [Streamable transport](protocol.md#streamable-transport) + 1. [Custom transports](protocol.md#stateless-mode) +1. [Authorization](protocol.md#authorization) +1. [Security](protocol.md#security) +1. [Utilities](protocol.md#utilities) + 1. [Cancellation](utilities.md#cancellation) + 1. [Ping](utilities.md#ping) + 1. [Progress](utilities.md#progress) + +## Client Features + +1. [Roots](client.md#roots) +1. [Sampling](client.md#sampling) +1. [Elicitation](clients.md#elicitation) + +## Server Features + +1. [Prompts](server.md#prompts) +1. [Resources](server.md#resources) +1. [Tools](tools.md) +1. [Utilities](server.md#utilities) + 1. [Completion](server.md#completion) + 1. [Logging](server.md#logging) + 1. [Pagination](server.md#pagination) + +# TroubleShooting + +See [troubleshooting.md](troubleshooting.md) for a troubleshooting guide. diff --git a/internal/docs/client.src.md b/internal/docs/client.src.md new file mode 100644 index 00000000..fc37d454 --- /dev/null +++ b/internal/docs/client.src.md @@ -0,0 +1,55 @@ +# Support for MCP client features + +%toc + +## Roots + +MCP allows clients to specify a set of filesystem +["roots"](https://modelcontextprotocol.io/specification/2025-06-18/client/roots). +The SDK supports this as follows: + +**Client-side**: The SDK client always has the `roots.listChanged` capability. +To add roots to a client, use the +[`Client.AddRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.AddRoots) +and +[`Client.RemoveRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.RemoveRoots) +methods. If any servers are already [connected](protocol.md#lifecycle) to the +client, a call to `AddRoot` or `RemoveRoots` will result in a +`notifications/roots/list_changed` notification to each connected server. + +**Server-side**: To query roots from the server, use the +[`ServerSession.ListRoots`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.ListRoots) +method. To receive notifications about root changes, set +[`ServerOptions.RootsListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.RootsListChangedHandler). + +%include ../../mcp/client_example_test.go roots - + +## Sampling + +[Sampling](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling) +is a way for servers to leverage the client's AI capabilities. It is +implemented in the SDK as follows: + +**Client-side**: To add the `sampling` capability to a client, set +[`ClientOptions.CreateMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.CreateMessageHandler). +This function is invoked whenever the server requests sampling. + +**Server-side**: To use sampling from the server, call +[`ServerSession.CreateMessage`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.CreateMessage). + +%include ../../mcp/client_example_test.go sampling - + +## Elicitation + +[Elicitation](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation) +allows servers to request user inputs. It is implemented in the SDK as follows: + +**Client-side**: To add the `elicitation` capability to a client, set +[`ClientOptions.ElicitationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ElicitationHandler). +The elicitation handler must return a result that matches the requested schema; +otherwise, elicitation returns an error. + +**Server-side**: To use elicitation from the server, call +[`ServerSession.Elicit`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Elicit). + +%include ../../mcp/client_example_test.go elicitation - diff --git a/internal/docs/doc.go b/internal/docs/doc.go new file mode 100644 index 00000000..7b23ad63 --- /dev/null +++ b/internal/docs/doc.go @@ -0,0 +1,15 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:generate -command weave go run golang.org/x/example/internal/cmd/weave@latest +//go:generate weave -o ../../docs/README.md ./README.src.md +//go:generate weave -o ../../docs/protocol.md ./protocol.src.md +//go:generate weave -o ../../docs/client.md ./client.src.md +//go:generate weave -o ../../docs/server.md ./server.src.md +//go:generate weave -o ../../docs/troubleshooting.md ./troubleshooting.src.md + +// The doc package generates the documentation at /doc, via go:generate. +// +// Tests in this package are used for examples. +package docs diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md new file mode 100644 index 00000000..88c023cc --- /dev/null +++ b/internal/docs/protocol.src.md @@ -0,0 +1,281 @@ +# Support for the MCP base protocol + +%toc + +## Lifecycle + +The SDK provides an API for defining both MCP clients and servers, and +connecting them over various transports. When a client and server are +connected, it creates a logical session, which follows the MCP spec's +[lifecycle](https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle). + +In this SDK, both a +[`Client`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client) +and +[`Server`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server) +can handle multiple peers. Every time a new peer is connected, it creates a new +session. + +- A `Client` is a logical MCP client, configured with various + [`ClientOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions). +- When a client is connected to a server using + [`Client.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Client.Connect), + it creates a + [`ClientSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession). + This session is initialized during the `Connect` method, and provides methods + to communicate with the server peer. +- A `Server` is a logical MCP server, configured with various + [`ServerOptions`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions). +- When a server is connected to a client using + [`Server.Connect`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.Connect), + it creates a + [`ServerSession`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession). + This session is not initialized until the client sends the + `notifications/initialized` message. Use `ServerOptions.InitializedHandler` + to listen for this event, or just use the session through various feature + handlers (such as a `ToolHandler`). Requests to the server are rejected until + the client has initialized the session. + +Both `ClientSession` and `ServerSession` have a `Close` method to terminate the +session, and a `Wait` method to await session termination by the peer. Typically, +it is the client's responsibility to end the session. + +%include ../../mcp/mcp_example_test.go lifecycle - + +## Transports + +A +[transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports) +can be used to send JSON-RPC messages from client to server, or vice-versa. + +In the SDK, this is achieved by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface, which creates a (logical) bidirectional stream of JSON-RPC messages. +Most transport implementations described below are specific to either the +client or server: a "client transport" is something that can be used to connect +a client to a server, and a "server transport" is something that can be used to +connect a server to a client. However, it's possible for a transport to be both +a client and server transport, such as the `InMemoryTransport` used in the +lifecycle example above. + +Transports should not be reused for multiple connections: if you need to create +multiple connections, use different transports. + +### Stdio Transport + +In the +[`stdio`](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio) +transport clients communicate with an MCP server running in a subprocess using +newline-delimited JSON over its stdin/stdout. + +**Client-side**: the client side of the `stdio` transport is implemented by +[`CommandTransport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CommandTransport), +which starts the a `exec.Cmd` as a subprocess and communicates over its +stdin/stdout. + +**Server-side**: the server side of the `stdio` transport is implemented by +`StdioTransport`, which connects over the current processes `os.Stdin` and +`os.Stdout`. + +### Streamable Transport + +The [streamable +transport](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http) +API is implemented across three types: + +- `StreamableHTTPHandler`: an`http.Handler` that serves streamable MCP + sessions. +- `StreamableServerTransport`: a `Transport` that implements the server side of + the streamable transport. +- `StreamableClientTransport`: a `Transport` that implements the client side of + the streamable transport. + +To create a streamable MCP server, you create a `StreamableHTTPHandler` and +pass it an `mcp.Server`: + +%include ../../mcp/streamable_example_test.go streamablehandler - + +The `StreamableHTTPHandler` handles the HTTP requests and creates a new +`StreamableServerTransport` for each new session. The transport is then used to +communicate with the client. + +On the client side, you create a `StreamableClientTransport` and use it to +connect to the server: + +```go +transport := &mcp.StreamableClientTransport{ + Endpoint: "http://localhost:8080/mcp", +} +client, err := mcp.Connect(ctx, transport, &mcp.ClientOptions{...}) +``` + +The `StreamableClientTransport` handles the HTTP requests and communicates with +the server using the streamable transport protocol. + +#### Stateless Mode + +The streamable server supports a _stateless mode_ by setting +[`StreamableHTTPOptions.Stateless`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableHTTPOptions.Stateless), +which is where the server does not perform any validation of the session id, +and uses a temporary session to handle requests. In this mode, it is impossible +for the server to make client requests, as there is no way for the client's +response to reach the session. + +However, it is still possible for the server to access the `ServerSession.ID` +to see the logical session + +> [!WARNING] +> Stateless mode is not directly discussed in the spec, and is still being +> defined. See modelcontextprotocol/modelcontextprotocol#1364, +> modelcontextprotocol/modelcontextprotocol#1372, or +> modelcontextprotocol/modelcontextprotocol#11442 for potential refinements. + +_See [examples/server/distributed](../examples/server/distributed/main.go) for +an example using statless mode to implement a server distributed across +multiple processes._ + +### Custom transports + +The SDK supports [custom +transports](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#custom-transports) +by implementing the +[`Transport`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Transport) +interface: a logical bidirectional stream of JSON-RPC messages. + +_Full example: [examples/server/custom-transport](../examples/server/custom-transport/main.go)._ + +### Concurrency + +In general, MCP offers no guarantees about concurrency semantics: if a client +or server sends a notification, the spec says nothing about when the peer +observes that notification relative to other request. However, the Go SDK +implements the following heuristics: + +- If a notifying method (such as `notifications/progress` or + `notifications/initialized`) returns, then it is guaranteed that the peer + observes that notification before other notifications or calls from the same + client goroutine. +- Calls (such as `tools/call`) are handled asynchronously with respect to + each other. + +See +[modelcontextprotocol/go-sdk#26](https://github.com/modelcontextprotocol/go-sdk/issues/26) +for more background. + +## Authorization + +### Server + +To write an MCP server that performs authorization, +use [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). +This function is middleware that wraps an HTTP handler, such as the one returned +by [`NewStreamableHTTPHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewStreamableHTTPHandler), to provide support for verifying bearer tokens. +The middleware function checks every request for an Authorization header with a bearer token, +and invokes the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) + passed to `RequireBearerToken` to parse the token and perform validation. +The middleware function checks expiration and scopes (if they are provided in +[`RequireBearerTokenOptions.Scopes`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.Scopes)), so the +`TokenVerifer` doesn't have to. +If [`RequireBearerTokenOptions.ResourceMetadataURL`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerTokenOptions.ResourceMetadataURL) is set and verification fails, +the middleware function sets the WWW-Authenticate header as required by the [Protected Resource +Metadata spec](https://datatracker.ietf.org/doc/html/rfc9728). + +Server handlers, such as tool handlers, can obtain the `TokenInfo` returned by the `TokenVerifier` +from `req.Extra.TokenInfo`, where `req` is the handler's request. (For example, a +[`CallToolRequest`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#CallToolRequest).) +HTTP handlers wrapped by the `RequireBearerToken` middleware can obtain the `TokenInfo` from the context +with [`auth.TokenInfoFromContext`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenInfoFromContext). + + +The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. + +### Client + +Client-side OAuth is implemented by setting +[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) +Additional support is forthcoming; see #493. + +## Security + +Here we discuss the mitigations described under +the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices) section, and how we handle them. + +### Confused Deputy + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, +happens on the MCP client. At present we don't provide client-side OAuth support. + + +### Token Passthrough + +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-2), accepting only tokens that were issued for the server, depends on the structure +of tokens and is the responsibility of the +[`TokenVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#TokenVerifier) +provided to +[`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken). + +### Session Hijacking + +The [mitigations](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation-3) are as follows: + +- _Verify all inbound requests_. The [`RequireBearerToken`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#RequireBearerToken) +middleware function will verify all HTTP requests that it receives. It is the +user's responsibility to wrap that function around all handlers in their server. + +- _Secure session IDs_. This SDK generates cryptographically secure session IDs by default. +If you create your own with +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID), it is your responsibility to ensure they are secure. +If you are using Go 1.24 or above, +we recommend using [`crypto/rand.Text`](https://pkg.go.dev/crypto/rand#Text) + +- _Binding session IDs to user information_. This is an application requirement, out of scope +for the SDK. You can create your own session IDs by setting +[`ServerOptions.GetSessionID`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.GetSessionID). + +## Utilities + +### Cancellation + +Cancellation is implemented with context cancellation. Cancelling a context +used in a method on `ClientSession` or `ServerSession` will terminate the RPC +and send a "notifications/cancelled" message to the peer. + +When an RPC exits due to a cancellation error, there's a guarantee that the +cancellation notification has been sent, but there's no guarantee that the +server has observed it (see [concurrency](#concurrency)). + +%include ../../mcp/mcp_example_test.go cancellation - + +### Ping + +[Ping](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/ping) +support is symmetrical for client and server. + +To initiate a ping, call +[`ClientSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Ping) +or +[`ServerSession.Ping`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Ping). + +To have the client or server session automatically ping its peer, and close the +session if the ping fails, set +[`ClientOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.KeepAlive) +or +[`ServerOptions.KeepAlive`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.KeepAlive). + +### Progress + +[Progress](https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/progress) +reporting is possible by reading the progress token from request metadata and +calling either +[`ClientSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.NotifyProgress) +or +[`ServerSession.NotifyProgress`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.NotifyProgress). +To listen to progress notifications, set +[`ClientOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ProgressNotificationHandler) +or +[`ServerOptions.ProgressNotificationHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.ProgressNotificationHandler). + +Issue #460 discusses some potential ergonomic improvements to this API. + +%include ../../mcp/mcp_example_test.go progress - diff --git a/internal/docs/server.src.md b/internal/docs/server.src.md new file mode 100644 index 00000000..c09ba63e --- /dev/null +++ b/internal/docs/server.src.md @@ -0,0 +1,272 @@ +# Support for MCP server features + +%toc + +## Prompts + +MCP servers can provide LLM prompt templates (called simply +[_prompts_](https://modelcontextprotocol.io/specification/2025-06-18/server/prompts)) +to clients. Every prompt has a required name which identifies it, and a set of +named arguments, which are strings. + +**Client-side**: To list the server's prompts, use the +[`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) +iterator, or the lower-level +[`ClientSession.ListPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListPrompts) +(see [pagination](#pagination) below). Set +[`ClientOptions.PromptListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.PromptListChangedHandler) +to be notified of changes in the list of prompts. + +Call +[`ClientSession.GetPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.GetPrompt) +to retrieve a prompt by name, providing arguments for expansion. + +**Server-side**: Use +[`Server.AddPrompt`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddPrompt) +to add a prompt to the server along with its handler. +The server will have the `prompts` capability if any prompt is added before the +server is connected to a client, or if +[`ServerOptions.HasPrompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/prompts/list_changed` +notification. + +%include ../../mcp/server_example_test.go prompts - + +## Resources + +In MCP terms, a _resource_ is some data referenced by a URI. +MCP servers can serve resources to clients. +They can register resources individually, or register a _resource template_ +that uses a URI pattern to describe a collection of resources. + + +**Client-side**: +Call [`ClientSession.ReadResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ReadResource) +to read a resource. +The SDK ensures that a read succeeds only if the URI matches a registered resource exactly, +or matches the URI pattern of a resource template. + +To list a server's resources and resource templates, use the +[`ClientSession.Resources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resources) +and +[`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) +iterators, or the lower-level `ListXXX` calls (see [pagination](#pagination)). +Set +[`ClientOptions.ResourceListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceListChangedHandler) +to be notified of changes in the lists of resources or resource templates. + +Clients can be notified when the contents of a resource changes by subscribing to the resource's URI. +Call +[`ClientSession.Subscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Subscribe) +to subscribe to a resource +and +[`ClientSession.Unsubscribe`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Unsubscribe) +to unsubscribe. +Set +[`ClientOptions.ResourceUpdatedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ResourceUpdatedHandler) +to be notified of changes to subscribed resources. + +**Server-side**: +Use +[`Server.AddResource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResource) +or +[`Server.AddResourceTemplate`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddResourceTemplate) +to add a resource or resource template to the server along with its handler. +A +[`ResourceHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ResourceHandler) +maps a URI to the contents of a resource, which can include text, binary data, +or both. +If `AddResource` or `AddResourceTemplate` is called before a server is connected, the server will have the +`resources` capability. +The server will have the `resources` capability if any resource or resource template is added before the +server is connected to a client, or if +[`ServerOptions.HasResources`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasResources) +is explicitly set. When a prompt is added, any clients already connected to the +server will be notified via a `notifications/resources/list_changed` +notification. + + +%include ../../mcp/server_example_test.go resources - + +## Tools + +MCP servers can provide +[tools](https://modelcontextprotocol.io/specification/2025-06-18/server/tools) +to allow clients to interact with external systems or functionality. Tools are +effectively remote function calls, and the Go SDK provides mechanisms to bind +them to ordinary Go functions. + +**Client-side**: To list the server's tools, use the +[`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) +iterator, or the lower-level +[`ClientSession.ListTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ListTools) +(see [pagination](#pagination)). Set +[`ClientOptions.ToolListChangedHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.ToolListChangedHandler) +to be notified of changes in the list of tools. + +To call a tool, use +[`ClientSession.CallTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.CallTool) +with `CallToolParams` holding the name and arguments of the tool to call. + +```go +res, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "my_tool", + Arguments: map[string]any{"name": "user"}, +}) +``` + +Arguments may be any value that can be marshaled to JSON. + +**Server-side**: the basic API for adding a tool is symmetrical with the API +for prompts or resources: +[`Server.AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Server.AddTool) +adds a +[`Tool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#Tool) to +the server along with its +[`ToolHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ToolHandler) +to handle it. The server will have the `tools` capability if any tool is added +before the server is connected to a client, or if +[`ServerOptions.HasTools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.HasPrompts) +is explicitly set. When a tool is added, any clients already connected to the +server will be notified via a `notifications/tools/list_changed` notification. + +However, the `Server.AddTool` API leaves it to the user to implement the tool +handler correctly according to the spec, providing very little out of the box. +In order to implement a tool, the user must do all of the following: + +- Provide a tool input and output schema. +- Validate the tool arguments against its input schema. +- Unmarshal the input schema into a Go value +- Execute the tool logic. +- Marshal the tool's structured output (if any) to JSON, and store it in the + result's `StructuredOutput` field as well as the unstructured `Content` field. +- Validate that output JSON against the tool's output schema. +- If any tool errors occurred, pack them into the unstructured content and set + `IsError` to `true.` + +For this reason, the SDK provides a generic +[`AddTool`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#AddTool) +function that handles this for you. It can bind a tool to any function with the +following shape: + +```go +func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) +``` + +This is like a `ToolHandler`, but with an extra arbitrary `In` input parameter, +and `Out` output parameter. + +Such a function can then be bound to the server using `AddTool`: + +```go +mcp.AddTool(server, &mcp.Tool{Name: "my_tool"}, handler) +``` + +This does the following automatically: + +- If `Tool.InputSchema` or `Tool.OutputSchema` are unset, the input and output + schemas are inferred from the `In` type, which must be a struct or map. + Optional `jsonschema` struct tags provide argument descriptions. +- Tool arguments are validated against the input schema. +- Tool arguments are marshaled into the `In` value. +- Tool output (the `Out` value) is marshaled into the result's + `StructuredOutput`, as well as the unstructured `Content`. +- Output is validated against the tool's output schema. +- If an ordinary error is returned, it is stored int the `CallToolResult` and + `IsError` is set to `true`. + +In fact, under ordinary circumstances, the user can ignore `CallToolRequest` +and `CallToolResult`. + +For a more realistic example, consider a tool that retrieves the weather: + +%include ../../mcp/tool_example_test.go weathertool - + +In this case, we want to customize part of the inferred schema, though we can +still infer the rest. Since we want to control the inference ourselves, we set +the `Tool.InputSchema` explicitly: + +%include ../../mcp/tool_example_test.go customschemas - + +_See [mcp/tool_example_test.go](../mcp/tool_example_test.go) for the full +example, or [examples/server/toolschemas](examples/server/toolschemas/main.go) +for more examples of customizing tool schemas._ + +## Utilities + +### Completion + +To support the +[completion](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/completion) +capability, the server needs a completion handler. + +**Client-side**: completion is called using the +[`ClientSession.Complete`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Complete) +method. + +**Server-side**: completion is enabled by setting +[`ServerOptions.CompletionHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.CompletionHandler). +If this field is set to a non-nil value, the server will advertise the +`completions` server capability, and use this handler to respond to completion +requests. + +%include ../../examples/server/completion/main.go completionhandler - + +### Logging + +MCP servers can send logging messages to MCP clients. +(This form of logging is distinct from server-side logging, where the +server produces logs that remain server-side, for use by server maintainers.) + +**Server-side**: +The minimum log level is part of the server state. +For stateful sessions, there is no default log level: no log messages will be sent +until the client calls `SetLevel` (see below). +For stateful sessions, the level defaults to "info". + +[`ServerSession.Log`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerSession.Log) is the low-level way for servers to log to clients. +It sends a logging notification to the client if the level of the message +is at least the minimum log level. + +For a simpler API, use [`NewLoggingHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#NewLoggingHandler) to obtain a [`slog.Handler`](https://pkg.go.dev/log/slog#Handler). +By setting [`LoggingHandlerOptions.MinInterval`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#LoggingHandlerOptions.MinInterval), the handler can be rate-limited +to avoid spamming clients with too many messages. + +Servers always report the logging capability. + + +**Client-side**: +Set [`ClientOptions.LoggingMessageHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientOptions.LoggingMessageHandler) to receive log messages. + +Call [`ClientSession.SetLevel`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.SetLevel) to change the log level for a session. + +%include ../../mcp/server_example_test.go logging - + +### Pagination + +Server-side feature lists may be +[paginated](https://modelcontextprotocol.io/specification/2025-06-18/server/utilities/pagination), +using cursors. The SDK supports this by default. + +**Client-side**: The `ClientSession` provides methods returning +[iterators](https://go.dev/blog/range-functions) for each feature type. +These iterators are an `iter.Seq2[Feature, error]`, where the error value +indicates whether page retrieval failed. + +- [`ClientSession.Prompts`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Prompts) + iterates prompts. +- [`ClientSession.Resource`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Resource) + iterates resources. +- [`ClientSession.ResourceTemplates`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.ResourceTemplates) + iterates resource templates. +- [`ClientSession.Tools`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ClientSession.Tools) + iterates tools. + +The `ClientSession` also exposes `ListXXX` methods for fine-grained control +over pagination. + +**Server-side**: pagination is on by default, so in general nothing is required +server-side. However, you may use +[`ServerOptions.PageSize`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#ServerOptions.PageSize) +to customize the page size. diff --git a/internal/docs/troubleshooting.src.md b/internal/docs/troubleshooting.src.md new file mode 100644 index 00000000..83342032 --- /dev/null +++ b/internal/docs/troubleshooting.src.md @@ -0,0 +1,42 @@ +# Troubleshooting + +The Model Context Protocol is a complicated spec that leaves some room for +interpretation. Client and server SDKs can behave differently, or can be more +or less strict about their inputs. And of course, bugs happen. + +When you encounter a problem using the Go SDK, these instructions can help +collect information that will be useful in debugging. Please try to provide +this information in a bug report, so that maintainers can more quickly +understand what's going wrong. + +And most of all, please do [file bugs](https://github.com/modelcontextprotocol/go-sdk/issues/new?template=bug_report.md). + +## Using the MCP inspector + +To debug an MCP server, you can use the [MCP +inspector](https://modelcontextprotocol.io/legacy/tools/inspector). This is +useful for testing your server and verifying that it works with the typescript +SDK, as well as inspecting MCP traffic. + +## Collecting MCP logs + +For [stdio](protocol.md#stdio-transport) transport connections, you can also +inspect MCP traffic using a `LoggingTransport`: + +%include ../../mcp/transport_example_test.go loggingtransport - + +That example uses a `bytes.Buffer`, but you can also log to a file, or to +`os.Stderr`. + +## Inspecting HTTP traffic + +There are a couple different ways to investigate traffic to an HTTP transport +([streamable](protocol.md#streamable-transport) or legacy SSE). + +The first is to use an HTTP middleware: + +%include ../../mcp/streamable_example_test.go httpmiddleware - + +The second is to use a general purpose tool to inspect http traffic, such as +[wireshark](https://www.wireshark.org/) or +[tcpdump](https://linux.die.net/man/8/tcpdump). diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 32239454..5549ee1c 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -65,8 +65,7 @@ type Connection struct { state inFlightState // accessed only in updateInFlight done chan struct{} // closed (under stateMu) when state.closed is true and all goroutines have completed - writer chan Writer // 1-buffered; stores the writer when not in use - + writer Writer handler Handler onInternalError func(error) @@ -125,7 +124,7 @@ func (c *Connection) updateInFlight(f func(*inFlightState)) { // that and avoided making any updates that would cause the state to be // non-idle.) if !s.idle() { - panic("jsonrpc2_v2: updateInFlight transitioned to non-idle when already done") + panic("jsonrpc2: updateInFlight transitioned to non-idle when already done") } return default: @@ -214,12 +213,11 @@ func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { c := &Connection{ state: inFlightState{closer: cfg.Closer}, done: make(chan struct{}), - writer: make(chan Writer, 1), + writer: cfg.Writer, onDone: cfg.OnDone, onInternalError: cfg.OnInternalError, } c.handler = cfg.Bind(c) - c.writer <- cfg.Writer c.start(ctx, cfg.Reader, cfg.Preempter) return c } @@ -239,7 +237,6 @@ func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Bind c := &Connection{ state: inFlightState{closer: rwc}, done: make(chan struct{}), - writer: make(chan Writer, 1), onDone: onDone, } // It's tempting to set a finalizer on c to verify that the state has gone @@ -259,7 +256,7 @@ func bindConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Bind } c.onInternalError = options.OnInternalError - c.writer <- framer.Writer(rwc) + c.writer = framer.Writer(rwc) reader := framer.Reader(rwc) c.start(ctx, reader, options.Preempter) return c @@ -377,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async return ac } +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + type AsyncCall struct { id ID ready chan struct{} // closed after response has been set @@ -428,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error { return json.Unmarshal(ac.response.Result, result) } -// Respond delivers a response to an incoming Call. -// -// Respond must be called exactly once for any message for which a handler -// returns ErrAsyncResponse. It must not be called for any other message. -func (c *Connection) Respond(id ID, result any, err error) error { - var req *incomingRequest - c.updateInFlight(func(s *inFlightState) { - req = s.incomingByID[id] - }) - if req == nil { - return c.internalErrorf("Request not found for ID %v", id) - } - - if err == ErrAsyncResponse { - // Respond is supposed to supply the asynchronous response, so it would be - // confusing to call Respond with an error that promises to call Respond - // again. - err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method) - } - return c.processResult("Respond", req, result, err) -} - // Cancel cancels the Context passed to the Handle call for the inbound message // with the given ID. // @@ -468,10 +483,32 @@ func (c *Connection) Cancel(id ID) { // Wait blocks until the connection is fully closed, but does not close it. func (c *Connection) Wait() error { + return c.wait(true) +} + +// wait for the connection to close, and aggregates the most cause of its +// termination, if abnormal. +// +// The fromWait argument allows this logic to be shared with Close, where we +// only want to expose the closeErr. +// +// (Previously, Wait also only returned the closeErr, which was misleading if +// the connection was broken for another reason). +func (c *Connection) wait(fromWait bool) error { var err error <-c.done c.updateInFlight(func(s *inFlightState) { - err = s.closeErr + if fromWait { + if !errors.Is(s.readErr, io.EOF) { + err = s.readErr + } + if err == nil && !errors.Is(s.writeErr, io.EOF) { + err = s.writeErr + } + } + if err == nil { + err = s.closeErr + } }) return err } @@ -487,8 +524,7 @@ func (c *Connection) Close() error { // Stop handling new requests, and interrupt the reader (by closing the // connection) as soon as the active requests finish. c.updateInFlight(func(s *inFlightState) { s.connClosing = true }) - - return c.Wait() + return c.wait(false) } // readIncoming collects inbound messages from the reader and delivers them, either responding @@ -579,11 +615,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter if preempter != nil { result, err := preempter.Preempt(req.ctx, req.Request) - if req.IsCall() && errors.Is(err, ErrAsyncResponse) { - // This request will remain in flight until Respond is called for it. - return - } - if !errors.Is(err, ErrNotHandled) { c.processResult("Preempt", req, result, err) return @@ -658,19 +689,20 @@ func (c *Connection) handleAsync() { continue } - result, err := c.handler.Handle(req.ctx, req.Request) - c.processResult(c.handler, req, result, err) + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch } } // processResult processes the result of a request and, if appropriate, sends a response. func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { switch err { - case ErrAsyncResponse: - if !req.IsCall() { - return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method) - } - return nil // This request is still in flight, so don't record the result yet. case ErrNotHandled, ErrMethodNotFound: // Add detail describing the unhandled method. err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) @@ -708,17 +740,17 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e } else if err != nil { err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err) } - if err != nil { - // TODO: can/should we do anything with this error beyond writing it to the event log? - // (Is this the right label to attach to the log?) - } + } + if err != nil { + // TODO: can/should we do anything with this error beyond writing it to the event log? + // (Is this the right label to attach to the log?) } // Cancel the request to free any associated resources. req.cancel() c.updateInFlight(func(s *inFlightState) { if s.incoming == 0 { - panic("jsonrpc2_v2: processResult called when incoming count is already zero") + panic("jsonrpc2: processResult called when incoming count is already zero") } s.incoming-- }) @@ -728,9 +760,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e // write is used by all things that write outgoing messages, including replies. // it makes sure that writes are atomic func (c *Connection) write(ctx context.Context, msg Message) error { - writer := <-c.writer - defer func() { c.writer <- writer }() - err := writer.Write(ctx, msg) + var err error + // Fail writes immediately if the connection is shutting down. + // + // TODO(rfindley): should we allow cancellation notifications through? It + // could be the case that writes can still succeed. + c.updateInFlight(func(s *inFlightState) { + err = s.shuttingDown(ErrServerClosing) + }) + if err == nil { + err = c.writer.Write(ctx, msg) + } + + // For rejected requests, we don't set the writeErr (which would break the + // connection). They can just be returned to the caller. + if errors.Is(err, ErrRejected) { + return err + } if err != nil && ctx.Err() == nil { // The call to Write failed, and since ctx.Err() is nil we can't attribute diff --git a/internal/jsonrpc2/frame.go b/internal/jsonrpc2/frame.go index d5bbe75b..46fcc9db 100644 --- a/internal/jsonrpc2/frame.go +++ b/internal/jsonrpc2/frame.go @@ -12,12 +12,14 @@ import ( "io" "strconv" "strings" + "sync" ) // Reader abstracts the transport mechanics from the JSON RPC protocol. // A Conn reads messages from the reader it was provided on construction, // and assumes that each call to Read fully transfers a single message, // or returns an error. +// // A reader is not safe for concurrent use, it is expected it will be used by // a single Conn in a safe manner. type Reader interface { @@ -29,8 +31,9 @@ type Reader interface { // A Conn writes messages using the writer it was provided on construction, // and assumes that each call to Write fully transfers a single message, // or returns an error. -// A writer is not safe for concurrent use, it is expected it will be used by -// a single Conn in a safe manner. +// +// A writer must be safe for concurrent use, as writes may occur concurrently +// in practice: libraries may make calls or respond to requests asynchronously. type Writer interface { // Write sends a message to the stream. Write(context.Context, Message) error @@ -62,7 +65,10 @@ func RawFramer() Framer { return rawFramer{} } type rawFramer struct{} type rawReader struct{ in *json.Decoder } -type rawWriter struct{ out io.Writer } +type rawWriter struct { + mu sync.Mutex + out io.Writer +} func (rawFramer) Reader(rw io.Reader) Reader { return &rawReader{in: json.NewDecoder(rw)} @@ -92,10 +98,14 @@ func (w *rawWriter) Write(ctx context.Context, msg Message) error { return ctx.Err() default: } + data, err := EncodeMessage(msg) if err != nil { return fmt.Errorf("marshaling message: %v", err) } + + w.mu.Lock() + defer w.mu.Unlock() _, err = w.out.Write(data) return err } @@ -107,7 +117,10 @@ func HeaderFramer() Framer { return headerFramer{} } type headerFramer struct{} type headerReader struct{ in *bufio.Reader } -type headerWriter struct{ out io.Writer } +type headerWriter struct { + mu sync.Mutex + out io.Writer +} func (headerFramer) Reader(rw io.Reader) Reader { return &headerReader{in: bufio.NewReader(rw)} @@ -180,6 +193,9 @@ func (w *headerWriter) Write(ctx context.Context, msg Message) error { return ctx.Err() default: } + w.mu.Lock() + defer w.mu.Unlock() + data, err := EncodeMessage(msg) if err != nil { return fmt.Errorf("marshaling message: %v", err) diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index b9c320c8..234e6ee3 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -22,13 +22,6 @@ var ( // If a Handler returns ErrNotHandled, the server replies with // ErrMethodNotFound. ErrNotHandled = errors.New("JSON RPC not handled") - - // ErrAsyncResponse is returned from a handler to indicate it will generate a - // response asynchronously. - // - // ErrAsyncResponse must not be returned for notifications, - // which do not receive responses. - ErrAsyncResponse = errors.New("JSON RPC asynchronous response") ) // Preempter handles messages on a connection before they are queued to the main diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 35f4e7f9..8c79300c 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -11,6 +11,7 @@ import ( "path" "reflect" "testing" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) @@ -143,6 +144,8 @@ func testConnection(t *testing.T, framer jsonrpc2.Framer) { t.Run(test.Name(), func(t *testing.T) { client, err := jsonrpc2.Dial(ctx, listener.Dialer(), binder{framer, func(h *handler) { + // Sleep a little to a void a race with setting conn.writer in jsonrpc2.bindConnection. + time.Sleep(50 * time.Millisecond) defer h.conn.Close() test.Invoke(t, ctx, h) if call, ok := test.(*call); ok { @@ -368,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error if err := json.Unmarshal(req.Params, &name); err != nil { return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) } + jsonrpc2.Async(ctx) waitFor := h.waiter(name) - go func() { - select { - case <-waitFor: - h.conn.Respond(req.ID, true, nil) - case <-ctx.Done(): - h.conn.Respond(req.ID, nil, ctx.Err()) - } - }() - return nil, jsonrpc2.ErrAsyncResponse + select { + case <-waitFor: + return true, nil + case <-ctx.Done(): + return nil, ctx.Err() + } default: return nil, jsonrpc2.ErrNotHandled } diff --git a/internal/jsonrpc2/messages.go b/internal/jsonrpc2/messages.go index 03371b91..9c0d5d69 100644 --- a/internal/jsonrpc2/messages.go +++ b/internal/jsonrpc2/messages.go @@ -19,7 +19,7 @@ type ID struct { // MakeID coerces the given Go value to an ID. The value is assumed to be the // default JSON marshaling of a Request identifier -- nil, float64, or string. // -// Returns an error if the value type was a valid Request ID type. +// Returns an error if the value type was not a valid Request ID type. // // TODO: ID can't be a json.Marshaler/Unmarshaler, because we want to omitzero. // Simplify this package by making ID json serializable once we can rely on @@ -56,6 +56,9 @@ type Request struct { Method string // Params is either a struct or an array with the parameters of the method. Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any } // Response is a Message used as a reply to a call Request. @@ -67,6 +70,9 @@ type Response struct { Error error // id of the request this is a response to. ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any } // StringID creates a new string request identifier. diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index 309b8002..8be2872e 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -13,30 +13,41 @@ import ( var ( // ErrParse is used when invalid JSON was received by the server. - ErrParse = NewError(-32700, "JSON RPC parse error") + ErrParse = NewError(-32700, "parse error") // ErrInvalidRequest is used when the JSON sent is not a valid Request object. - ErrInvalidRequest = NewError(-32600, "JSON RPC invalid request") + ErrInvalidRequest = NewError(-32600, "invalid request") // ErrMethodNotFound should be returned by the handler when the method does // not exist / is not available. - ErrMethodNotFound = NewError(-32601, "JSON RPC method not found") + ErrMethodNotFound = NewError(-32601, "method not found") // ErrInvalidParams should be returned by the handler when method // parameter(s) were invalid. - ErrInvalidParams = NewError(-32602, "JSON RPC invalid params") + ErrInvalidParams = NewError(-32602, "invalid params") // ErrInternal indicates a failure to process a call correctly - ErrInternal = NewError(-32603, "JSON RPC internal error") + ErrInternal = NewError(-32603, "internal error") // The following errors are not part of the json specification, but // compliant extensions specific to this implementation. // ErrServerOverloaded is returned when a message was refused due to a // server being temporarily unable to accept any new messages. - ErrServerOverloaded = NewError(-32000, "JSON RPC overloaded") + ErrServerOverloaded = NewError(-32000, "overloaded") // ErrUnknown should be used for all non coded errors. - ErrUnknown = NewError(-32001, "JSON RPC unknown error") + ErrUnknown = NewError(-32001, "unknown error") // ErrServerClosing is returned for calls that arrive while the server is closing. - ErrServerClosing = NewError(-32004, "JSON RPC server is closing") + ErrServerClosing = NewError(-32004, "server is closing") // ErrClientClosing is a dummy error returned for calls initiated while the client is closing. - ErrClientClosing = NewError(-32003, "JSON RPC client is closing") + ErrClientClosing = NewError(-32003, "client is closing") + + // The following errors have special semantics for MCP transports + + // ErrRejected may be wrapped to return errors from calls to Writer.Write + // that signal that the request was rejected by the transport layer as + // invalid. + // + // Such failures do not indicate that the connection is broken, but rather + // should be returned to the caller to indicate that the specific request is + // invalid in the current context. + ErrRejected = NewError(-32004, "rejected by transport") ) const wireVersion = "2.0" diff --git a/internal/oauthex/auth_meta.go b/internal/oauthex/auth_meta.go new file mode 100644 index 00000000..1f075f8a --- /dev/null +++ b/internal/oauthex/auth_meta.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" +) + +// AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, +// as defined in [RFC 8414]. +// +// Not supported: +// - signed metadata +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414) +type AuthServerMeta struct { + // GENERATED BY GEMINI 2.5. + + // Issuer is the REQUIRED URL identifying the authorization server. + Issuer string `json:"issuer"` + + // AuthorizationEndpoint is the REQUIRED URL of the server's OAuth 2.0 authorization endpoint. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the REQUIRED URL of the server's OAuth 2.0 token endpoint. + TokenEndpoint string `json:"token_endpoint"` + + // JWKSURI is the REQUIRED URL of the server's JSON Web Key Set [JWK] document. + JWKSURI string `json:"jwks_uri"` + + // RegistrationEndpoint is the RECOMMENDED URL of the server's OAuth 2.0 Dynamic Client Registration endpoint. + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + // ScopesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "scope" values that this server supports. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ResponseTypesSupported is a REQUIRED JSON array of strings containing a list of the OAuth 2.0 + // "response_type" values that this server supports. + ResponseTypesSupported []string `json:"response_types_supported"` + + // ResponseModesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "response_mode" values that this server supports. + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + // GrantTypesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // grant type values that this server supports. + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + + // TokenEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // client authentication methods supported by this token endpoint. + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + + // TokenEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings containing + // a list of the JWS signing algorithms ("alg" values) supported by the token endpoint for + // the signature on the JWT used to authenticate the client. + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + // ServiceDocumentation is a RECOMMENDED URL of a page containing human-readable documentation + // for the service. + ServiceDocumentation string `json:"service_documentation,omitempty"` + + // UILocalesSupported is a RECOMMENDED JSON array of strings representing supported + // BCP47 [RFC5646] language tag values for display in the user interface. + UILocalesSupported []string `json:"ui_locales_supported,omitempty"` + + // OpPolicyURI is a RECOMMENDED URL that the server provides to the person registering + // the client to read about the server's operator policies. + OpPolicyURI string `json:"op_policy_uri,omitempty"` + + // OpTOSURI is a RECOMMENDED URL that the server provides to the person registering the + // client to read about the server's terms of service. + OpTOSURI string `json:"op_tos_uri,omitempty"` + + // RevocationEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 revocation endpoint. + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + // RevocationEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this revocation endpoint. + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + // RevocationEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the revocation + // endpoint for the signature on the JWT used to authenticate the client. + RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + + // IntrospectionEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 introspection endpoint. + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + // IntrospectionEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this introspection endpoint. + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + // IntrospectionEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the introspection + // endpoint for the signature on the JWT used to authenticate the client. + IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + + // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // PKCE code challenge methods supported by this authorization server. + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` +} + +var wellKnownPaths = []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", +} + +// GetAuthServerMeta issues a GET request to retrieve authorization server metadata +// from an OAuth authorization server with the given issuerURL. +// +// It follows [RFC 8414]: +// - The well-known paths specified there are inserted into the URL's path, one at time. +// The first to succeed is used. +// - The Issuer field is checked against issuerURL. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414 +func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { + var errs []error + for _, p := range wellKnownPaths { + u, err := prependToPath(issuerURL, p) + if err != nil { + // issuerURL is bad; no point in continuing. + return nil, err + } + asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) + if err == nil { + if asm.Issuer != issuerURL { // section 3.3 + // Security violation; don't keep trying. + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + } + return asm, nil + } + errs = append(errs, err) + } + return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) +} diff --git a/internal/oauthex/auth_meta_test.go b/internal/oauthex/auth_meta_test.go new file mode 100644 index 00000000..b83402f2 --- /dev/null +++ b/internal/oauthex/auth_meta_test.go @@ -0,0 +1,28 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package oauthex + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestAuthMetaParse(t *testing.T) { + // Verify that we parse Google's auth server metadata. + data, err := os.ReadFile(filepath.FromSlash("testdata/google-auth-meta.json")) + if err != nil { + t.Fatal(err) + } + var a AuthServerMeta + if err := json.Unmarshal(data, &a); err != nil { + t.Fatal(err) + } + // Spot check. + if g, w := a.Issuer, "https://accounts.google.com"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/internal/oauthex/oauth2.go b/internal/oauthex/oauth2.go new file mode 100644 index 00000000..de164499 --- /dev/null +++ b/internal/oauthex/oauth2.go @@ -0,0 +1,67 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +// prependToPath prepends pre to the path of urlStr. +// When pre is the well-known path, this is the algorithm specified in both RFC 9728 +// section 3.1 and RFC 8414 section 3.1. +func prependToPath(urlStr, pre string) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", err + } + p := "/" + strings.Trim(pre, "/") + if u.Path != "" { + p += "/" + } + + u.Path = p + strings.TrimLeft(u.Path, "/") + return u.String(), nil +} + +// getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both +// RFC 9728 and RFC 8414. +// It will not read more than limit bytes from the body. +func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64) (*T, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if c == nil { + c = http.DefaultClient + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Specs require a 200. + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status %s", res.Status) + } + // Specs require application/json. + if ct := res.Header.Get("Content-Type"); ct != "application/json" { + return nil, fmt.Errorf("bad content type %q", ct) + } + + var t T + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + if err := dec.Decode(&t); err != nil { + return nil, err + } + return &t, nil +} diff --git a/internal/oauthex/oauth2_test.go b/internal/oauthex/oauth2_test.go new file mode 100644 index 00000000..9c3da156 --- /dev/null +++ b/internal/oauthex/oauth2_test.go @@ -0,0 +1,270 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestSplitChallenges(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "single challenge no params", + input: `Basic`, + want: []string{`Basic`}, + }, + { + name: "single challenge with params", + input: `Bearer realm="example.com", error="invalid_token"`, + want: []string{`Bearer realm="example.com", error="invalid_token"`}, + }, + { + name: "single challenge with comma in quoted string", + input: `Bearer realm="example, with comma"`, + want: []string{`Bearer realm="example, with comma"`}, + }, + { + name: "two challenges", + input: `Basic, Bearer realm="example"`, + want: []string{`Basic`, ` Bearer realm="example"`}, + }, + { + name: "multiple challenges complex", + input: `Newauth realm="apps", Basic, Bearer realm="example.com", error="invalid_token"`, + want: []string{`Newauth realm="apps"`, ` Basic`, ` Bearer realm="example.com", error="invalid_token"`}, + }, + { + name: "challenge with escaped quote", + input: `Bearer realm="example \"quoted\""`, + want: []string{`Bearer realm="example \"quoted\""`}, + }, + { + name: "empty input", + input: "", + want: []string{""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := splitChallenges(tt.input) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("splitChallenges() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSplitChallengesError(t *testing.T) { + if _, err := splitChallenges(`"Bearer"`); err == nil { + t.Fatal("got nil, want error") + } +} + +func TestParseSingleChallenge(t *testing.T) { + tests := []struct { + name string + input string + want challenge + wantErr bool + }{ + { + name: "scheme only", + input: "Basic", + want: challenge{ + Scheme: "basic", + }, + wantErr: false, + }, + { + name: "scheme with one quoted param", + input: `Bearer realm="example.com"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example.com"}, + }, + wantErr: false, + }, + { + name: "scheme with one unquoted param", + input: `Bearer realm=example.com`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example.com"}, + }, + wantErr: false, + }, + { + name: "scheme with multiple params", + input: `Bearer realm="example", error="invalid_token", error_description="The token expired"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{ + "realm": "example", + "error": "invalid_token", + "error_description": "The token expired", + }, + }, + wantErr: false, + }, + { + name: "scheme with multiple unquoted params", + input: `Bearer realm=example, error=invalid_token, error_description=The token expired`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{ + "realm": "example", + "error": "invalid_token", + "error_description": "The token expired", + }, + }, + wantErr: false, + }, + { + name: "case-insensitive scheme and keys", + input: `BEARER ReAlM="example"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example"}, + }, + wantErr: false, + }, + { + name: "param with escaped quote", + input: `Bearer realm="example \"foo\" bar"`, + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": `example "foo" bar`}, + }, + wantErr: false, + }, + { + name: "param without quotes (token)", + input: "Bearer realm=example.com", + want: challenge{ + Scheme: "bearer", + Params: map[string]string{"realm": "example.com"}, + }, + wantErr: false, + }, + { + name: "malformed param - no value", + input: "Bearer realm=", + wantErr: true, + }, + { + name: "malformed param - unterminated quote", + input: `Bearer realm="example`, + wantErr: true, + }, + { + name: "malformed param - missing comma", + input: `Bearer realm="a" error="b"`, + wantErr: true, + }, + { + name: "malformed param - initial equal", + input: `Bearer ="a"`, + wantErr: true, + }, + { + name: "empty input", + input: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseSingleChallenge(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("parseSingleChallenge() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseSingleChallenge() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetProtectedResourceMetadata(t *testing.T) { + ctx := context.Background() + t.Run("FromHeader", func(t *testing.T) { + h := &fakeResourceHandler{serveWWWAuthenticate: true} + server := httptest.NewTLSServer(h) + h.installHandlers(server.URL) + client := server.Client() + res, err := client.Get(server.URL + "/resource") + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusUnauthorized { + t.Fatal("want unauth") + } + prm, err := GetProtectedResourceMetadataFromHeader(ctx, res.Header, client) + if err != nil { + t.Fatal(err) + } + if prm == nil { + t.Fatal("nil prm") + } + }) + t.Run("FromID", func(t *testing.T) { + h := &fakeResourceHandler{serveWWWAuthenticate: false} + server := httptest.NewTLSServer(h) + h.installHandlers(server.URL) + client := server.Client() + prm, err := GetProtectedResourceMetadataFromID(ctx, server.URL, client) + if err != nil { + t.Fatal(err) + } + if prm == nil { + t.Fatal("nil prm") + } + }) +} + +type fakeResourceHandler struct { + http.ServeMux + serveWWWAuthenticate bool +} + +func (h *fakeResourceHandler) installHandlers(serverURL string) { + path := "/.well-known/oauth-protected-resource" + url := serverURL + path + h.Handle("GET /resource", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if h.serveWWWAuthenticate { + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, url)) + } + w.WriteHeader(http.StatusUnauthorized) + })) + h.Handle("GET "+path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // If there is a WWW-Authenticate header, the resource field is the value of that header. + // If not, it's the server URL without the "/.well-known/..." part. + resource := serverURL + if h.serveWWWAuthenticate { + resource = url + } + prm := &ProtectedResourceMetadata{Resource: resource} + if err := json.NewEncoder(w).Encode(prm); err != nil { + panic(err) + } + })) +} diff --git a/internal/oauthex/resource_meta.go b/internal/oauthex/resource_meta.go new file mode 100644 index 00000000..71d52cde --- /dev/null +++ b/internal/oauthex/resource_meta.go @@ -0,0 +1,358 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "unicode" + + "github.com/modelcontextprotocol/go-sdk/internal/util" +) + +const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resource" + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server by its ID. +// The resource ID is an HTTPS URL, typically with a host:port and possibly a path. +// For example: +// +// https://example.com/server +// +// This function, following the spec (§3), inserts the default well-known path into the +// URL. In our example, the result would be +// +// https://example.com/.well-known/oauth-protected-resource/server +// +// It then retrieves the metadata at that location using the given client (or the +// default client if nil) and validates its resource field against resourceID. +func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) + + u, err := url.Parse(resourceID) + if err != nil { + return nil, err + } + // Insert well-known URI into URL. + u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) + return getPRM(ctx, u.String(), c, resourceID) +} + +// GetProtectedResourceMetadataFromHeader retrieves protected resource metadata +// using information in the given header, using the given client (or the default +// client if nil). +// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request, +// It then validates the resource field of the resulting metadata against the given URL. +// If there is no URL in the request, it returns nil, nil. +func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") + headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] + if len(headers) == 0 { + return nil, nil + } + cs, err := parseWWWAuthenticate(headers) + if err != nil { + return nil, err + } + url := resourceMetadataURL(cs) + if url == "" { + return nil, nil + } + return getPRM(ctx, url, c, url) +} + +// getPRM makes a GET request to the given URL, and validates the response. +// As part of the validation, it compares the returned resource field to wantResource. +func getPRM(ctx context.Context, url string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { + if !strings.HasPrefix(strings.ToUpper(url), "HTTPS://") { + return nil, fmt.Errorf("resource URL %q does not use HTTPS", url) + } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, url, 1<<20) + if err != nil { + return nil, err + } + // Validate the Resource field to thwart impersonation attacks (section 3.3). + if prm.Resource != wantResource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + } + return prm, nil +} + +// challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +// resourceMetadataURL returns a resource metadata URL from the given challenges, +// or the empty string if there is none. +func resourceMetadataURL(cs []challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// parseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. +func parseWWWAuthenticate(headers []string) ([]challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. + challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. + commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} diff --git a/internal/oauthex/testdata/google-auth-meta.json b/internal/oauthex/testdata/google-auth-meta.json new file mode 100644 index 00000000..258b7534 --- /dev/null +++ b/internal/oauthex/testdata/google-auth-meta.json @@ -0,0 +1,57 @@ +{ + "issuer": "https://accounts.google.com", + "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", + "device_authorization_endpoint": "https://oauth2.googleapis.com/device/code", + "token_endpoint": "https://oauth2.googleapis.com/token", + "userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo", + "revocation_endpoint": "https://oauth2.googleapis.com/revoke", + "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs", + "response_types_supported": [ + "code", + "token", + "id_token", + "code token", + "code id_token", + "token id_token", + "code token id_token", + "none" + ], + "subject_types_supported": [ + "public" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "scopes_supported": [ + "openid", + "email", + "profile" + ], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic" + ], + "claims_supported": [ + "aud", + "email", + "email_verified", + "exp", + "family_name", + "given_name", + "iat", + "iss", + "name", + "picture", + "sub" + ], + "code_challenge_methods_supported": [ + "plain", + "S256" + ], + "grant_types_supported": [ + "authorization_code", + "refresh_token", + "urn:ietf:params:oauth:grant-type:device_code", + "urn:ietf:params:oauth:grant-type:jwt-bearer" + ] +} diff --git a/internal/readme/Makefile b/internal/readme/Makefile deleted file mode 100644 index 5e484184..00000000 --- a/internal/readme/Makefile +++ /dev/null @@ -1,19 +0,0 @@ -# This makefile builds ../README.md from the files in this directory. - -OUTFILE=../../README.md - -$(OUTFILE): build README.src.md - go run golang.org/x/example/internal/cmd/weave@latest README.src.md > $(OUTFILE) - -# Compile all the code used in the README. -build: $(wildcard */*.go) - go build -o /tmp/mcp-readme/ ./... - -# Preview the README on GitHub. -# $HOME/markdown must be a github repo. -# Visit https://github.com/$HOME/markdown to see the result. -preview: $(OUTFILE) - cp $(OUTFILE) $$HOME/markdown/ - (cd $$HOME/markdown/ && git commit -m . README.md && git push) - -.PHONY: build preview diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 629629a4..10ce9b67 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,61 +1,75 @@ -# MCP Go SDK +# MCP Go SDK v0.6.0 -[![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) +[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/modelcontextprotocol/go-sdk) -This repository contains an unreleased implementation of the official Go -software development kit (SDK) for the Model Context Protocol (MCP). +***BREAKING CHANGES*** -**WARNING**: The SDK should be considered unreleased, and is currently unstable -and subject to breaking changes. Please test it out and file bug reports or API -proposals, but don't use it in real projects. See the issue tracker for known -issues and missing features. We aim to release a stable version of the SDK in -August, 2025. +This version contains minor breaking changes. +See the [release notes]( +https://github.com/modelcontextprotocol/go-sdk/releases/tag/v0.6.0) for details. -## Design +[![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) -The design doc for this SDK is at [design.md](./design/design.md), which was -initially reviewed at -[modelcontextprotocol/discussions/364](https://github.com/orgs/modelcontextprotocol/discussions/364). +This repository contains an implementation of the official Go software +development kit (SDK) for the Model Context Protocol (MCP). -Further design discussion should occur in -[issues](https://github.com/modelcontextprotocol/go-sdk/issues) (for concrete -proposals) or -[discussions](https://github.com/modelcontextprotocol/go-sdk/discussions) for -open-ended discussion. See CONTRIBUTING.md for details. +> [!IMPORTANT] +> The SDK is in release-candidate state, and is going to be tagged v1.0.0 +> soon (see https://github.com/modelcontextprotocol/go-sdk/issues/328). +> We do not anticipate significant API changes or instability. Please use it +> and [file issues](https://github.com/modelcontextprotocol/go-sdk/issues/new/choose). -## Package documentation +## Package / Feature documentation -The SDK consists of two importable packages: +The SDK consists of several importable packages: - The [`github.com/modelcontextprotocol/go-sdk/mcp`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp) package defines the primary APIs for constructing and using MCP clients and servers. - The - [`github.com/modelcontextprotocol/go-sdk/jsonschema`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonschema) - package provides an implementation of [JSON - Schema](https://json-schema.org/), used for MCP tool input and output schema. + [`github.com/modelcontextprotocol/go-sdk/jsonrpc`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/jsonrpc) package is for users implementing + their own transports. +- The + [`github.com/modelcontextprotocol/go-sdk/auth`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth) + package provides some primitives for supporting oauth. +- The + [`github.com/modelcontextprotocol/go-sdk/oauthex`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/oauthex) + package provides extensions to the OAuth protocol, such as ProtectedResourceMetadata. + +The SDK endeavors to implement the full MCP spec. The [`docs/`](/docs/) directory +contains feature documentation, mapping the MCP spec to the packages above. -## Example +## Getting started -In this example, an MCP client communicates with an MCP server running in a -sidecar process: +To get started creating an MCP server, create an `mcp.Server` instance, add +features to it, and then run it over an `mcp.Transport`. For example, this +server adds a single simple tool, and then connects it to clients over +stdin/stdout: + +%include server/server.go - + +To communicate with that server, create an `mcp.Client` and connect it to the +corresponding server, by running the server command and communicating over its +stdin/stdout: %include client/client.go - -Here's an example of the corresponding server component, which communicates -with its client over stdin/stdout: +The [`examples/`](/examples/) directory contains more example clients and +servers. -%include server/server.go - +## Contributing -The `examples/` directory contains more example clients and servers. +We welcome contributions to the SDK! Please see See +[CONTRIBUTING.md](/CONTRIBUTING.md) for details of how to contribute. -## Acknowledgements +## Acknowledgements / Alternatives -Several existing Go MCP SDKs inspired the development and design of this -official SDK, notably [mcp-go](https://github.com/mark3labs/mcp-go), authored -by Ed Zynda. We are grateful to Ed as well as the other contributors to mcp-go, -and to authors and contributors of other SDKs such as +Several third party Go MCP SDKs inspired the development and design of this +official SDK, and continue to be viable alternatives, notably +[mcp-go](https://github.com/mark3labs/mcp-go), originally authored by Ed Zynda. +We are grateful to Ed as well as the other contributors to mcp-go, and to +authors and contributors of other SDKs such as [mcp-golang](https://github.com/metoro-io/mcp-golang) and [go-mcp](https://github.com/ThinkInAIXYZ/go-mcp). Thanks to their work, there is a thriving ecosystem of Go MCP clients and servers. diff --git a/internal/readme/build.sh b/internal/readme/build.sh deleted file mode 100755 index b721d505..00000000 --- a/internal/readme/build.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/sh - -# Build README.md from the files in this directory. -# Must be invoked from the mcp directory. - -cd ../internal/readme - -outfile=../../README.md - -# Compile all the code used in the README. -go build -o /tmp/mcp-readme/ ./... -# Combine the code with the text in README.src.md. -# TODO: when at Go 1.24, use a tool directive for weave. -go run golang.org/x/example/internal/cmd/weave@latest README.src.md > $outfile - -if [[ $1 = '-preview' ]]; then - # Preview the README on GitHub. - # $HOME/markdown must be a github repo. - # Visit https://github.com/$HOME/markdown to see the result. - cp $outfile $HOME/markdown/ - (cd $HOME/markdown/ && git commit -m . README.md && git push) -fi diff --git a/internal/readme/client/client.go b/internal/readme/client/client.go index 44bc515c..9f357964 100644 --- a/internal/readme/client/client.go +++ b/internal/readme/client/client.go @@ -17,11 +17,11 @@ func main() { ctx := context.Background() // Create a new client, with no features. - client := mcp.NewClient("mcp-client", "v1.0.0", nil) + client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) // Connect to a server over stdin/stdout - transport := mcp.NewCommandTransport(exec.Command("myserver")) - session, err := client.Connect(ctx, transport) + transport := &mcp.CommandTransport{Command: exec.Command("myserver")} + session, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } @@ -44,4 +44,4 @@ func main() { } } -//!- +// !- diff --git a/internal/readme/doc.go b/internal/readme/doc.go new file mode 100644 index 00000000..34bc60c8 --- /dev/null +++ b/internal/readme/doc.go @@ -0,0 +1,9 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:generate go run golang.org/x/example/internal/cmd/weave@latest -o ../../README.md ./README.src.md + +// The readme package is used to generate README.md at the top-level of this +// repo. Regenerate the README with go generate. +package readme diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 534e0798..e9996027 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -12,26 +12,24 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -type HiParams struct { - Name string `json:"name"` +type Input struct { + Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + params.Arguments.Name}}, - }, nil +type Output struct { + Greeting string `json:"greeting" jsonschema:"the greeting to tell to the user"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, input Input) (*mcp.CallToolResult, Output, error) { + return nil, Output{Greeting: "Hi " + input.Name}, nil } func main() { // Create a server with a single tool. - server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools( - mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name of the person to greet")), - )), - ) + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v1.0.0"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects - if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { + if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { log.Fatal(err) } } diff --git a/internal/testing/fake_auth_server.go b/internal/testing/fake_auth_server.go new file mode 100644 index 00000000..79fafe4e --- /dev/null +++ b/internal/testing/fake_auth_server.go @@ -0,0 +1,149 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package testing + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + authServerPort = ":8080" + issuer = "http://localhost" + authServerPort + tokenExpiry = time.Hour +) + +var jwtSigningKey = []byte("fake-secret-key") + +type authCodeInfo struct { + codeChallenge string + redirectURI string +} + +type state struct { + authCodes map[string]authCodeInfo +} + +// NewFakeAuthMux constructs a ServeMux that implements an OAuth 2.1 authentication +// server. It should be used with [httptest.NewTLSServer]. +func NewFakeAuthMux() *http.ServeMux { + s := &state{authCodes: make(map[string]authCodeInfo)} + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/oauth-authorization-server", s.handleMetadata) + mux.HandleFunc("/authorize", s.handleAuthorize) + mux.HandleFunc("/token", s.handleToken) + return mux +} + +func (s *state) handleMetadata(w http.ResponseWriter, r *http.Request) { + issuer := "https://localhost:" + r.URL.Port() + metadata := map[string]any{ + "issuer": issuer, + "authorization_endpoint": issuer + "/authorize", + "token_endpoint": issuer + "/token", + "jwks_uri": issuer + "/.well-known/jwks.json", + "scopes_supported": []string{"openid", "profile", "email"}, + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "code_challenge_methods_supported": []string{"S256"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(metadata) +} + +func (s *state) handleAuthorize(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + responseType := query.Get("response_type") + redirectURI := query.Get("redirect_uri") + codeChallenge := query.Get("code_challenge") + codeChallengeMethod := query.Get("code_challenge_method") + + if responseType != "code" { + http.Error(w, "unsupported_response_type", http.StatusBadRequest) + return + } + if redirectURI == "" { + http.Error(w, "invalid_request (no redirect_uri)", http.StatusBadRequest) + return + } + if codeChallenge == "" || codeChallengeMethod != "S256" { + http.Error(w, "invalid_request (code challenge is not S256)", http.StatusBadRequest) + return + } + if query.Get("client_id") == "" { + http.Error(w, "invalid_request (missing client_id)", http.StatusBadRequest) + return + } + + authCode := "fake-auth-code-" + fmt.Sprintf("%d", time.Now().UnixNano()) + s.authCodes[authCode] = authCodeInfo{ + codeChallenge: codeChallenge, + redirectURI: redirectURI, + } + + redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, authCode, query.Get("state")) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +func (s *state) handleToken(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + grantType := r.Form.Get("grant_type") + code := r.Form.Get("code") + codeVerifier := r.Form.Get("code_verifier") + // Ignore redirect_uri; it is not required in 2.1. + // https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-13.html#redirect-uri-in-token-request + + if grantType != "authorization_code" { + http.Error(w, "unsupported_grant_type", http.StatusBadRequest) + return + } + + authCodeInfo, ok := s.authCodes[code] + if !ok { + http.Error(w, "invalid_grant", http.StatusBadRequest) + return + } + delete(s.authCodes, code) + + // PKCE verification + hasher := sha256.New() + hasher.Write([]byte(codeVerifier)) + calculatedChallenge := base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) + if calculatedChallenge != authCodeInfo.codeChallenge { + http.Error(w, "invalid_grant", http.StatusBadRequest) + return + } + + // Issue JWT + now := time.Now() + claims := jwt.MapClaims{ + "iss": issuer, + "sub": "fake-user-id", + "aud": "fake-client-id", + "exp": now.Add(tokenExpiry).Unix(), + "iat": now.Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + accessToken, err := token.SignedString(jwtSigningKey) + if err != nil { + http.Error(w, "server_error", http.StatusInternalServerError) + return + } + + tokenResponse := map[string]any{ + "access_token": accessToken, + "token_type": "Bearer", + "expires_in": int(tokenExpiry.Seconds()), + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(tokenResponse) +} diff --git a/internal/util/util.go b/internal/util/util.go index f8c5baf9..4b5c325f 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -8,9 +8,7 @@ import ( "cmp" "fmt" "iter" - "reflect" "slices" - "strings" ) // Helpers below are copied from gopls' moremaps package. @@ -38,43 +36,6 @@ func KeySlice[M ~map[K]V, K comparable, V any](m M) []K { return r } -type JSONInfo struct { - Omit bool // unexported or first tag element is "-" - Name string // Go field name or first tag element. Empty if Omit is true. - Settings map[string]bool // "omitempty", "omitzero", etc. -} - -// FieldJSONInfo reports information about how encoding/json -// handles the given struct field. -// If the field is unexported, JSONInfo.Omit is true and no other JSONInfo field -// is populated. -// If the field is exported and has no tag, then Name is the field's name and all -// other fields are false. -// Otherwise, the information is obtained from the tag. -func FieldJSONInfo(f reflect.StructField) JSONInfo { - if !f.IsExported() { - return JSONInfo{Omit: true} - } - info := JSONInfo{Name: f.Name} - if tag, ok := f.Tag.Lookup("json"); ok { - name, rest, found := strings.Cut(tag, ",") - // "-" means omit, but "-," means the name is "-" - if name == "-" && !found { - return JSONInfo{Omit: true} - } - if name != "" { - info.Name = name - } - if len(rest) > 0 { - info.Settings = map[string]bool{} - for _, s := range strings.Split(rest, ",") { - info.Settings[s] = true - } - } - } - return info -} - // Wrapf wraps *errp with the given formatted message if *errp is not nil. func Wrapf(errp *error, format string, args ...any) { if *errp != nil { diff --git a/internal/util/util_test.go b/internal/util/util_test.go deleted file mode 100644 index 6a2b8676..00000000 --- a/internal/util/util_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package util - -import ( - "reflect" - "testing" -) - -func TestJSONInfo(t *testing.T) { - type S struct { - A int - B int `json:","` - C int `json:"-"` - D int `json:"-,"` - E int `json:"echo"` - F int `json:"foxtrot,omitempty"` - g int `json:"golf"` - } - want := []JSONInfo{ - {Name: "A"}, - {Name: "B"}, - {Omit: true}, - {Name: "-"}, - {Name: "echo"}, - {Name: "foxtrot", Settings: map[string]bool{"omitempty": true}}, - {Omit: true}, - } - tt := reflect.TypeFor[S]() - for i := range tt.NumField() { - got := FieldJSONInfo(tt.Field(i)) - if !reflect.DeepEqual(got, want[i]) { - t.Errorf("got %+v, want %+v", got, want[i]) - } - } -} diff --git a/jsonrpc/jsonrpc.go b/jsonrpc/jsonrpc.go new file mode 100644 index 00000000..1cf1202f --- /dev/null +++ b/jsonrpc/jsonrpc.go @@ -0,0 +1,39 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package jsonrpc exposes part of a JSON-RPC v2 implementation +// for use by mcp transport authors. +package jsonrpc + +import "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + +type ( + // ID is a JSON-RPC request ID. + ID = jsonrpc2.ID + // Message is a JSON-RPC message. + Message = jsonrpc2.Message + // Request is a JSON-RPC request. + Request = jsonrpc2.Request + // Response is a JSON-RPC response. + Response = jsonrpc2.Response +) + +// MakeID coerces the given Go value to an ID. The value is assumed to be the +// default JSON marshaling of a Request identifier -- nil, float64, or string. +// +// Returns an error if the value type was not a valid Request ID type. +func MakeID(v any) (ID, error) { + return jsonrpc2.MakeID(v) +} + +// EncodeMessage serializes a JSON-RPC message to its wire format. +func EncodeMessage(msg Message) ([]byte, error) { + return jsonrpc2.EncodeMessage(msg) +} + +// DecodeMessage deserializes JSON-RPC wire format data into a Message. +// It returns either a Request or Response based on the message content. +func DecodeMessage(data []byte) (Message, error) { + return jsonrpc2.DecodeMessage(data) +} diff --git a/jsonschema/annotations.go b/jsonschema/annotations.go deleted file mode 100644 index a7ede1c6..00000000 --- a/jsonschema/annotations.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import "maps" - -// An annotations tracks certain properties computed by keywords that are used by validation. -// ("Annotation" is the spec's term.) -// In particular, the unevaluatedItems and unevaluatedProperties keywords need to know which -// items and properties were evaluated (validated successfully). -type annotations struct { - allItems bool // all items were evaluated - endIndex int // 1+largest index evaluated by prefixItems - evaluatedIndexes map[int]bool // set of indexes evaluated by contains - allProperties bool // all properties were evaluated - evaluatedProperties map[string]bool // set of properties evaluated by various keywords -} - -// noteIndex marks i as evaluated. -func (a *annotations) noteIndex(i int) { - if a.evaluatedIndexes == nil { - a.evaluatedIndexes = map[int]bool{} - } - a.evaluatedIndexes[i] = true -} - -// noteEndIndex marks items with index less than end as evaluated. -func (a *annotations) noteEndIndex(end int) { - if end > a.endIndex { - a.endIndex = end - } -} - -// noteProperty marks prop as evaluated. -func (a *annotations) noteProperty(prop string) { - if a.evaluatedProperties == nil { - a.evaluatedProperties = map[string]bool{} - } - a.evaluatedProperties[prop] = true -} - -// noteProperties marks all the properties in props as evaluated. -func (a *annotations) noteProperties(props map[string]bool) { - a.evaluatedProperties = merge(a.evaluatedProperties, props) -} - -// merge adds b's annotations to a. -// a must not be nil. -func (a *annotations) merge(b *annotations) { - if b == nil { - return - } - if b.allItems { - a.allItems = true - } - if b.endIndex > a.endIndex { - a.endIndex = b.endIndex - } - a.evaluatedIndexes = merge(a.evaluatedIndexes, b.evaluatedIndexes) - if b.allProperties { - a.allProperties = true - } - a.evaluatedProperties = merge(a.evaluatedProperties, b.evaluatedProperties) -} - -// merge adds t's keys to s and returns s. -// If s is nil, it returns a copy of t. -func merge[K comparable](s, t map[K]bool) map[K]bool { - if s == nil { - return maps.Clone(t) - } - maps.Copy(s, t) - return s -} diff --git a/jsonschema/doc.go b/jsonschema/doc.go deleted file mode 100644 index f25b000a..00000000 --- a/jsonschema/doc.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -/* -Package jsonschema is an implementation of the [JSON Schema specification], -a JSON-based format for describing the structure of JSON data. -The package can be used to read schemas for code generation, and to validate data using the -draft 2020-12 specification. Validation with other drafts or custom meta-schemas -is not supported. - -Construct a [Schema] as you would any Go struct (for example, by writing a struct -literal), or unmarshal a JSON schema into a [Schema] in the usual way (with [encoding/json], -for instance). It can then be used for code generation or other purposes without -further processing. - -# Validation - -Before using a Schema to validate a JSON value, you must first resolve it by calling -[Schema.Resolve]. -The call [Resolved.Validate] on the result to validate a JSON value. -The value must be a Go value that looks like the result of unmarshaling a JSON -value into an [any] or a struct. For example, the JSON value - - {"name": "Al", "scores": [90, 80, 100]} - -could be represented as the Go value - - map[string]any{ - "name": "Al", - "scores": []any{90, 80, 100}, - } - -or as a value of this type: - - type Player struct { - Name string `json:"name"` - Scores []int `json:"scores"` - } - -# Inference - -The [For] function returns a [Schema] describing the given Go type. -The type cannot contain any function or channel types, and any map types must have a string key. -For example, calling For on the above Player type results in this schema: - - { - "properties": { - "name": { - "type": "string" - }, - "scores": { - "type": "array", - "items": {"type": "integer"} - } - "required": ["name", "scores"], - "additionalProperties": {"not": {}} - } - } - -# Deviations from the specification - -Regular expressions are processed with Go's regexp package, which differs from ECMA 262, -most significantly in not supporting back-references. -See [this table of differences] for more. - -The value of the "format" keyword is recorded in the Schema, but is ignored during validation. -It does not even produce [annotations]. - -[JSON Schema specification]: https://json-schema.org -[this table of differences] https://github.com/dlclark/regexp2?tab=readme-ov-file#compare-regexp-and-regexp2 -[annotations]: https://json-schema.org/draft/2020-12/json-schema-core#name-annotations -*/ -package jsonschema diff --git a/jsonschema/infer.go b/jsonschema/infer.go deleted file mode 100644 index 1334bdf1..00000000 --- a/jsonschema/infer.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file contains functions that infer a schema from a Go type. - -package jsonschema - -import ( - "fmt" - "reflect" - - "github.com/modelcontextprotocol/go-sdk/internal/util" -) - -// For constructs a JSON schema object for the given type argument. -// -// It translates Go types into compatible JSON schema types, as follows: -// - Strings have schema type "string". -// - Bools have schema type "boolean". -// - Signed and unsigned integer types have schema type "integer". -// - Floating point types have schema type "number". -// - Slices and arrays have schema type "array", and a corresponding schema -// for items. -// - Maps with string key have schema type "object", and corresponding -// schema for additionalProperties. -// - Structs have schema type "object", and disallow additionalProperties. -// Their properties are derived from exported struct fields, using the -// struct field JSON name. Fields that are marked "omitempty" are -// considered optional; all other fields become required properties. -// -// For returns an error if t contains (possibly recursively) any of the following Go -// types, as they are incompatible with the JSON schema spec. -// - maps with key other than 'string' -// - function types -// - complex numbers -// - unsafe pointers -// -// The types must not have cycles. -func For[T any]() (*Schema, error) { - // TODO: consider skipping incompatible fields, instead of failing. - s, err := forType(reflect.TypeFor[T]()) - if err != nil { - var z T - return nil, fmt.Errorf("For[%T](): %w", z, err) - } - return s, nil -} - -func forType(t reflect.Type) (*Schema, error) { - // Follow pointers: the schema for *T is almost the same as for T, except that - // an explicit JSON "null" is allowed for the pointer. - allowNull := false - for t.Kind() == reflect.Pointer { - allowNull = true - t = t.Elem() - } - - var ( - s = new(Schema) - err error - ) - - switch t.Kind() { - case reflect.Bool: - s.Type = "boolean" - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Uintptr: - s.Type = "integer" - - case reflect.Float32, reflect.Float64: - s.Type = "number" - - case reflect.Interface: - // Unrestricted - - case reflect.Map: - if t.Key().Kind() != reflect.String { - return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) - } - s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem()) - if err != nil { - return nil, fmt.Errorf("computing map value schema: %v", err) - } - - case reflect.Slice, reflect.Array: - s.Type = "array" - s.Items, err = forType(t.Elem()) - if err != nil { - return nil, fmt.Errorf("computing element schema: %v", err) - } - if t.Kind() == reflect.Array { - s.MinItems = Ptr(t.Len()) - s.MaxItems = Ptr(t.Len()) - } - - case reflect.String: - s.Type = "string" - - case reflect.Struct: - s.Type = "object" - // no additional properties are allowed - s.AdditionalProperties = falseSchema() - - for i := range t.NumField() { - field := t.Field(i) - info := util.FieldJSONInfo(field) - if info.Omit { - continue - } - if s.Properties == nil { - s.Properties = make(map[string]*Schema) - } - s.Properties[info.Name], err = forType(field.Type) - if err != nil { - return nil, err - } - if !info.Settings["omitempty"] && !info.Settings["omitzero"] { - s.Required = append(s.Required, info.Name) - } - } - - default: - return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) - } - if allowNull && s.Type != "" { - s.Types = []string{"null", s.Type} - s.Type = "" - } - return s, nil -} diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go deleted file mode 100644 index 9325b832..00000000 --- a/jsonschema/infer_test.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema_test - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" -) - -func forType[T any]() *jsonschema.Schema { - s, err := jsonschema.For[T]() - if err != nil { - panic(err) - } - return s -} - -func TestForType(t *testing.T) { - type schema = jsonschema.Schema - tests := []struct { - name string - got *jsonschema.Schema - want *jsonschema.Schema - }{ - {"string", forType[string](), &schema{Type: "string"}}, - {"int", forType[int](), &schema{Type: "integer"}}, - {"int16", forType[int16](), &schema{Type: "integer"}}, - {"uint32", forType[int16](), &schema{Type: "integer"}}, - {"float64", forType[float64](), &schema{Type: "number"}}, - {"bool", forType[bool](), &schema{Type: "boolean"}}, - {"intmap", forType[map[string]int](), &schema{ - Type: "object", - AdditionalProperties: &schema{Type: "integer"}, - }}, - {"anymap", forType[map[string]any](), &schema{ - Type: "object", - AdditionalProperties: &schema{}, - }}, - { - "struct", - forType[struct { - F int `json:"f"` - G []float64 - P *bool - Skip string `json:"-"` - NoSkip string `json:",omitempty"` - unexported float64 - unexported2 int `json:"No"` - }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "f": {Type: "integer"}, - "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Types: []string{"null", "boolean"}}, - "NoSkip": {Type: "string"}, - }, - Required: []string{"f", "G", "P"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, - }, - }, - { - "no sharing", - forType[struct{ X, Y int }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "X": {Type: "integer"}, - "Y": {Type: "integer"}, - }, - Required: []string{"X", "Y"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if diff := cmp.Diff(test.want, test.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ForType mismatch (-want +got):\n%s", diff) - } - // These schemas should all resolve. - if _, err := test.got.Resolve(nil); err != nil { - t.Fatalf("Resolving: %v", err) - } - }) - } -} diff --git a/jsonschema/json_pointer.go b/jsonschema/json_pointer.go deleted file mode 100644 index 7310b9b4..00000000 --- a/jsonschema/json_pointer.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements JSON Pointers. -// A JSON Pointer is a path that refers to one JSON value within another. -// If the path is empty, it refers to the root value. -// Otherwise, it is a sequence of slash-prefixed strings, like "/points/1/x", -// selecting successive properties (for JSON objects) or items (for JSON arrays). -// For example, when applied to this JSON value: -// { -// "points": [ -// {"x": 1, "y": 2}, -// {"x": 3, "y": 4} -// ] -// } -// -// the JSON Pointer "/points/1/x" refers to the number 3. -// See the spec at https://datatracker.ietf.org/doc/html/rfc6901. - -package jsonschema - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" - - "github.com/modelcontextprotocol/go-sdk/internal/util" -) - -var ( - jsonPointerEscaper = strings.NewReplacer("~", "~0", "/", "~1") - jsonPointerUnescaper = strings.NewReplacer("~0", "~", "~1", "/") -) - -func escapeJSONPointerSegment(s string) string { - return jsonPointerEscaper.Replace(s) -} - -func unescapeJSONPointerSegment(s string) string { - return jsonPointerUnescaper.Replace(s) -} - -// parseJSONPointer splits a JSON Pointer into a sequence of segments. It doesn't -// convert strings to numbers, because that depends on the traversal: a segment -// is treated as a number when applied to an array, but a string when applied to -// an object. See section 4 of the spec. -func parseJSONPointer(ptr string) (segments []string, err error) { - if ptr == "" { - return nil, nil - } - if ptr[0] != '/' { - return nil, fmt.Errorf("JSON Pointer %q does not begin with '/'", ptr) - } - // Unlike file paths, consecutive slashes are not coalesced. - // Split is nicer than Cut here, because it gets a final "/" right. - segments = strings.Split(ptr[1:], "/") - if strings.Contains(ptr, "~") { - // Undo the simple escaping rules that allow one to include a slash in a segment. - for i := range segments { - segments[i] = unescapeJSONPointerSegment(segments[i]) - } - } - return segments, nil -} - -// dereferenceJSONPointer returns the Schema that sptr points to within s, -// or an error if none. -// This implementation suffices for JSON Schema: pointers are applied only to Schemas, -// and refer only to Schemas. -func dereferenceJSONPointer(s *Schema, sptr string) (_ *Schema, err error) { - defer util.Wrapf(&err, "JSON Pointer %q", sptr) - - segments, err := parseJSONPointer(sptr) - if err != nil { - return nil, err - } - v := reflect.ValueOf(s) - for _, seg := range segments { - switch v.Kind() { - case reflect.Pointer: - v = v.Elem() - if !v.IsValid() { - return nil, errors.New("navigated to nil reference") - } - fallthrough // if valid, can only be a pointer to a Schema - - case reflect.Struct: - // The segment must refer to a field in a Schema. - if v.Type() != reflect.TypeFor[Schema]() { - return nil, fmt.Errorf("navigated to non-Schema %s", v.Type()) - } - v = lookupSchemaField(v, seg) - if !v.IsValid() { - return nil, fmt.Errorf("no schema field %q", seg) - } - case reflect.Slice, reflect.Array: - // The segment must be an integer without leading zeroes that refers to an item in the - // slice or array. - if seg == "-" { - return nil, errors.New("the JSON Pointer array segment '-' is not supported") - } - if len(seg) > 1 && seg[0] == '0' { - return nil, fmt.Errorf("segment %q has leading zeroes", seg) - } - n, err := strconv.Atoi(seg) - if err != nil { - return nil, fmt.Errorf("invalid int: %q", seg) - } - if n < 0 || n >= v.Len() { - return nil, fmt.Errorf("index %d is out of bounds for array of length %d", n, v.Len()) - } - v = v.Index(n) - // Cannot be invalid. - case reflect.Map: - // The segment must be a key in the map. - v = v.MapIndex(reflect.ValueOf(seg)) - if !v.IsValid() { - return nil, fmt.Errorf("no key %q in map", seg) - } - default: - return nil, fmt.Errorf("value %s (%s) is not a schema, slice or map", v, v.Type()) - } - } - if s, ok := v.Interface().(*Schema); ok { - return s, nil - } - return nil, fmt.Errorf("does not refer to a schema, but to a %s", v.Type()) -} - -// lookupSchemaField returns the value of the field with the given name in v, -// or the zero value if there is no such field or it is not of type Schema or *Schema. -func lookupSchemaField(v reflect.Value, name string) reflect.Value { - if name == "type" { - // The "type" keyword may refer to Type or Types. - // At most one will be non-zero. - if t := v.FieldByName("Type"); !t.IsZero() { - return t - } - return v.FieldByName("Types") - } - if sf, ok := schemaFieldMap[name]; ok { - return v.FieldByIndex(sf.Index) - } - return reflect.Value{} -} diff --git a/jsonschema/json_pointer_test.go b/jsonschema/json_pointer_test.go deleted file mode 100644 index 54b84bed..00000000 --- a/jsonschema/json_pointer_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "strings" - "testing" -) - -func TestDereferenceJSONPointer(t *testing.T) { - s := &Schema{ - AllOf: []*Schema{{}, {}}, - Defs: map[string]*Schema{ - "": {Properties: map[string]*Schema{"": {}}}, - "A": {}, - "B": { - Defs: map[string]*Schema{ - "X": {}, - "Y": {}, - }, - }, - "/~": {}, - "~1": {}, - }, - } - - for _, tt := range []struct { - ptr string - want any - }{ - {"", s}, - {"/$defs/A", s.Defs["A"]}, - {"/$defs/B", s.Defs["B"]}, - {"/$defs/B/$defs/X", s.Defs["B"].Defs["X"]}, - {"/$defs//properties/", s.Defs[""].Properties[""]}, - {"/allOf/1", s.AllOf[1]}, - {"/$defs/~1~0", s.Defs["/~"]}, - {"/$defs/~01", s.Defs["~1"]}, - } { - got, err := dereferenceJSONPointer(s, tt.ptr) - if err != nil { - t.Fatal(err) - } - if got != tt.want { - t.Errorf("%s:\ngot %+v\nwant %+v", tt.ptr, got, tt.want) - } - } -} - -func TestDerefernceJSONPointerErrors(t *testing.T) { - s := &Schema{ - Type: "t", - Items: &Schema{}, - Required: []string{"a"}, - } - for _, tt := range []struct { - ptr string - want string // error must contain this string - }{ - {"x", "does not begin"}, // parse error: no initial '/' - {"/minItems", "does not refer to a schema"}, - {"/minItems/x", "navigated to nil"}, - {"/required/-", "not supported"}, - {"/required/01", "leading zeroes"}, - {"/required/x", "invalid int"}, - {"/required/1", "out of bounds"}, - {"/properties/x", "no key"}, - } { - _, err := dereferenceJSONPointer(s, tt.ptr) - if err == nil { - t.Errorf("%q: succeeded, want failure", tt.ptr) - } else if !strings.Contains(err.Error(), tt.want) { - t.Errorf("%q: error is %q, which does not contain %q", tt.ptr, err, tt.want) - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/applicator.json b/jsonschema/meta-schemas/draft2020-12/meta/applicator.json deleted file mode 100644 index f4775974..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/applicator.json +++ /dev/null @@ -1,45 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/applicator", - "$dynamicAnchor": "meta", - - "title": "Applicator vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "prefixItems": { "$ref": "#/$defs/schemaArray" }, - "items": { "$dynamicRef": "#meta" }, - "contains": { "$dynamicRef": "#meta" }, - "additionalProperties": { "$dynamicRef": "#meta" }, - "properties": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "default": {} - }, - "patternProperties": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "propertyNames": { "format": "regex" }, - "default": {} - }, - "dependentSchemas": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "default": {} - }, - "propertyNames": { "$dynamicRef": "#meta" }, - "if": { "$dynamicRef": "#meta" }, - "then": { "$dynamicRef": "#meta" }, - "else": { "$dynamicRef": "#meta" }, - "allOf": { "$ref": "#/$defs/schemaArray" }, - "anyOf": { "$ref": "#/$defs/schemaArray" }, - "oneOf": { "$ref": "#/$defs/schemaArray" }, - "not": { "$dynamicRef": "#meta" } - }, - "$defs": { - "schemaArray": { - "type": "array", - "minItems": 1, - "items": { "$dynamicRef": "#meta" } - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/content.json b/jsonschema/meta-schemas/draft2020-12/meta/content.json deleted file mode 100644 index 76e3760d..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/content.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/content", - "$dynamicAnchor": "meta", - - "title": "Content vocabulary meta-schema", - - "type": ["object", "boolean"], - "properties": { - "contentEncoding": { "type": "string" }, - "contentMediaType": { "type": "string" }, - "contentSchema": { "$dynamicRef": "#meta" } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/core.json b/jsonschema/meta-schemas/draft2020-12/meta/core.json deleted file mode 100644 index 69186228..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/core.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/core", - "$dynamicAnchor": "meta", - - "title": "Core vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "$id": { - "$ref": "#/$defs/uriReferenceString", - "$comment": "Non-empty fragments not allowed.", - "pattern": "^[^#]*#?$" - }, - "$schema": { "$ref": "#/$defs/uriString" }, - "$ref": { "$ref": "#/$defs/uriReferenceString" }, - "$anchor": { "$ref": "#/$defs/anchorString" }, - "$dynamicRef": { "$ref": "#/$defs/uriReferenceString" }, - "$dynamicAnchor": { "$ref": "#/$defs/anchorString" }, - "$vocabulary": { - "type": "object", - "propertyNames": { "$ref": "#/$defs/uriString" }, - "additionalProperties": { - "type": "boolean" - } - }, - "$comment": { - "type": "string" - }, - "$defs": { - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" } - } - }, - "$defs": { - "anchorString": { - "type": "string", - "pattern": "^[A-Za-z_][-A-Za-z0-9._]*$" - }, - "uriString": { - "type": "string", - "format": "uri" - }, - "uriReferenceString": { - "type": "string", - "format": "uri-reference" - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json b/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json deleted file mode 100644 index 3479e669..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/format-annotation.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/format-annotation", - "$dynamicAnchor": "meta", - - "title": "Format vocabulary meta-schema for annotation results", - "type": ["object", "boolean"], - "properties": { - "format": { "type": "string" } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json b/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json deleted file mode 100644 index 4049ab21..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/meta-data.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/meta-data", - "$dynamicAnchor": "meta", - - "title": "Meta-data vocabulary meta-schema", - - "type": ["object", "boolean"], - "properties": { - "title": { - "type": "string" - }, - "description": { - "type": "string" - }, - "default": true, - "deprecated": { - "type": "boolean", - "default": false - }, - "readOnly": { - "type": "boolean", - "default": false - }, - "writeOnly": { - "type": "boolean", - "default": false - }, - "examples": { - "type": "array", - "items": true - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json b/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json deleted file mode 100644 index 93779e54..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/unevaluated.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/unevaluated", - "$dynamicAnchor": "meta", - - "title": "Unevaluated applicator vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "unevaluatedItems": { "$dynamicRef": "#meta" }, - "unevaluatedProperties": { "$dynamicRef": "#meta" } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/meta/validation.json b/jsonschema/meta-schemas/draft2020-12/meta/validation.json deleted file mode 100644 index ebb75db7..00000000 --- a/jsonschema/meta-schemas/draft2020-12/meta/validation.json +++ /dev/null @@ -1,95 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/meta/validation", - "$dynamicAnchor": "meta", - - "title": "Validation vocabulary meta-schema", - "type": ["object", "boolean"], - "properties": { - "type": { - "anyOf": [ - { "$ref": "#/$defs/simpleTypes" }, - { - "type": "array", - "items": { "$ref": "#/$defs/simpleTypes" }, - "minItems": 1, - "uniqueItems": true - } - ] - }, - "const": true, - "enum": { - "type": "array", - "items": true - }, - "multipleOf": { - "type": "number", - "exclusiveMinimum": 0 - }, - "maximum": { - "type": "number" - }, - "exclusiveMaximum": { - "type": "number" - }, - "minimum": { - "type": "number" - }, - "exclusiveMinimum": { - "type": "number" - }, - "maxLength": { "$ref": "#/$defs/nonNegativeInteger" }, - "minLength": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, - "pattern": { - "type": "string", - "format": "regex" - }, - "maxItems": { "$ref": "#/$defs/nonNegativeInteger" }, - "minItems": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, - "uniqueItems": { - "type": "boolean", - "default": false - }, - "maxContains": { "$ref": "#/$defs/nonNegativeInteger" }, - "minContains": { - "$ref": "#/$defs/nonNegativeInteger", - "default": 1 - }, - "maxProperties": { "$ref": "#/$defs/nonNegativeInteger" }, - "minProperties": { "$ref": "#/$defs/nonNegativeIntegerDefault0" }, - "required": { "$ref": "#/$defs/stringArray" }, - "dependentRequired": { - "type": "object", - "additionalProperties": { - "$ref": "#/$defs/stringArray" - } - } - }, - "$defs": { - "nonNegativeInteger": { - "type": "integer", - "minimum": 0 - }, - "nonNegativeIntegerDefault0": { - "$ref": "#/$defs/nonNegativeInteger", - "default": 0 - }, - "simpleTypes": { - "enum": [ - "array", - "boolean", - "integer", - "null", - "number", - "object", - "string" - ] - }, - "stringArray": { - "type": "array", - "items": { "type": "string" }, - "uniqueItems": true, - "default": [] - } - } -} diff --git a/jsonschema/meta-schemas/draft2020-12/schema.json b/jsonschema/meta-schemas/draft2020-12/schema.json deleted file mode 100644 index d5e2d31c..00000000 --- a/jsonschema/meta-schemas/draft2020-12/schema.json +++ /dev/null @@ -1,58 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://json-schema.org/draft/2020-12/schema", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/core": true, - "https://json-schema.org/draft/2020-12/vocab/applicator": true, - "https://json-schema.org/draft/2020-12/vocab/unevaluated": true, - "https://json-schema.org/draft/2020-12/vocab/validation": true, - "https://json-schema.org/draft/2020-12/vocab/meta-data": true, - "https://json-schema.org/draft/2020-12/vocab/format-annotation": true, - "https://json-schema.org/draft/2020-12/vocab/content": true - }, - "$dynamicAnchor": "meta", - - "title": "Core and Validation specifications meta-schema", - "allOf": [ - {"$ref": "meta/core"}, - {"$ref": "meta/applicator"}, - {"$ref": "meta/unevaluated"}, - {"$ref": "meta/validation"}, - {"$ref": "meta/meta-data"}, - {"$ref": "meta/format-annotation"}, - {"$ref": "meta/content"} - ], - "type": ["object", "boolean"], - "$comment": "This meta-schema also defines keywords that have appeared in previous drafts in order to prevent incompatible extensions as they remain in common use.", - "properties": { - "definitions": { - "$comment": "\"definitions\" has been replaced by \"$defs\".", - "type": "object", - "additionalProperties": { "$dynamicRef": "#meta" }, - "deprecated": true, - "default": {} - }, - "dependencies": { - "$comment": "\"dependencies\" has been split and replaced by \"dependentSchemas\" and \"dependentRequired\" in order to serve their differing semantics.", - "type": "object", - "additionalProperties": { - "anyOf": [ - { "$dynamicRef": "#meta" }, - { "$ref": "meta/validation#/$defs/stringArray" } - ] - }, - "deprecated": true, - "default": {} - }, - "$recursiveAnchor": { - "$comment": "\"$recursiveAnchor\" has been replaced by \"$dynamicAnchor\".", - "$ref": "meta/core#/$defs/anchorString", - "deprecated": true - }, - "$recursiveRef": { - "$comment": "\"$recursiveRef\" has been replaced by \"$dynamicRef\".", - "$ref": "meta/core#/$defs/uriReferenceString", - "deprecated": true - } - } -} diff --git a/jsonschema/resolve.go b/jsonschema/resolve.go deleted file mode 100644 index 58a44b2b..00000000 --- a/jsonschema/resolve.go +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file deals with preparing a schema for validation, including various checks, -// optimizations, and the resolution of cross-schema references. - -package jsonschema - -import ( - "errors" - "fmt" - "net/url" - "reflect" - "regexp" - "strings" -) - -// A Resolved consists of a [Schema] along with associated information needed to -// validate documents against it. -// A Resolved has been validated against its meta-schema, and all its references -// (the $ref and $dynamicRef keywords) have been resolved to their referenced Schemas. -// Call [Schema.Resolve] to obtain a Resolved from a Schema. -type Resolved struct { - root *Schema - // map from $ids to their schemas - resolvedURIs map[string]*Schema -} - -// Schema returns the schema that was resolved. -// It must not be modified. -func (r *Resolved) Schema() *Schema { return r.root } - -// A Loader reads and unmarshals the schema at uri, if any. -type Loader func(uri *url.URL) (*Schema, error) - -// ResolveOptions are options for [Schema.Resolve]. -type ResolveOptions struct { - // BaseURI is the URI relative to which the root schema should be resolved. - // If non-empty, must be an absolute URI (one that starts with a scheme). - // It is resolved (in the URI sense; see [url.ResolveReference]) with root's - // $id property. - // If the resulting URI is not absolute, then the schema cannot contain - // relative URI references. - BaseURI string - // Loader loads schemas that are referred to by a $ref but are not under the - // root schema (remote references). - // If nil, resolving a remote reference will return an error. - Loader Loader - // ValidateDefaults determines whether to validate values of "default" keywords - // against their schemas. - // The [JSON Schema specification] does not require this, but it is - // recommended if defaults will be used. - // - // [JSON Schema specification]: https://json-schema.org/understanding-json-schema/reference/annotations - ValidateDefaults bool -} - -// Resolve resolves all references within the schema and performs other tasks that -// prepare the schema for validation. -// If opts is nil, the default values are used. -func (root *Schema) Resolve(opts *ResolveOptions) (*Resolved, error) { - // There are up to five steps required to prepare a schema to validate. - // 1. Load: read the schema from somewhere and unmarshal it. - // This schema (root) may have been loaded or created in memory, but other schemas that - // come into the picture in step 4 will be loaded by the given loader. - // 2. Check: validate the schema against a meta-schema, and perform other well-formedness checks. - // Precompute some values along the way. - // 3. Resolve URIs: determine the base URI of the root and all its subschemas, and - // resolve (in the URI sense) all identifiers and anchors with their bases. This step results - // in a map from URIs to schemas within root. - // 4. Resolve references: all refs in the schemas are replaced with the schema they refer to. - // 5. (Optional.) If opts.ValidateDefaults is true, validate the defaults. - if root.path != "" { - return nil, fmt.Errorf("jsonschema: Resolve: %s already resolved", root) - } - r := &resolver{loaded: map[string]*Resolved{}} - if opts != nil { - r.opts = *opts - } - var base *url.URL - if r.opts.BaseURI == "" { - base = &url.URL{} // so we can call ResolveReference on it - } else { - var err error - base, err = url.Parse(r.opts.BaseURI) - if err != nil { - return nil, fmt.Errorf("parsing base URI: %w", err) - } - } - - if r.opts.Loader == nil { - r.opts.Loader = func(uri *url.URL) (*Schema, error) { - return nil, errors.New("cannot resolve remote schemas: no loader passed to Schema.Resolve") - } - } - - resolved, err := r.resolve(root, base) - if err != nil { - return nil, err - } - if r.opts.ValidateDefaults { - if err := resolved.validateDefaults(); err != nil { - return nil, err - } - } - // TODO: before we return, throw away anything we don't need for validation. - return resolved, nil -} - -// A resolver holds the state for resolution. -type resolver struct { - opts ResolveOptions - // A cache of loaded and partly resolved schemas. (They may not have had their - // refs resolved.) The cache ensures that the loader will never be called more - // than once with the same URI, and that reference cycles are handled properly. - loaded map[string]*Resolved -} - -func (r *resolver) resolve(s *Schema, baseURI *url.URL) (*Resolved, error) { - if baseURI.Fragment != "" { - return nil, fmt.Errorf("base URI %s must not have a fragment", baseURI) - } - if err := s.check(); err != nil { - return nil, err - } - - m, err := resolveURIs(s, baseURI) - if err != nil { - return nil, err - } - rs := &Resolved{root: s, resolvedURIs: m} - // Remember the schema by both the URI we loaded it from and its canonical name, - // which may differ if the schema has an $id. - // We must set the map before calling resolveRefs, or ref cycles will cause unbounded recursion. - r.loaded[baseURI.String()] = rs - r.loaded[s.uri.String()] = rs - - if err := r.resolveRefs(rs); err != nil { - return nil, err - } - return rs, nil -} - -func (root *Schema) check() error { - // Check for structural validity. Do this first and fail fast: - // bad structure will cause other code to panic. - if err := root.checkStructure(); err != nil { - return err - } - - var errs []error - report := func(err error) { errs = append(errs, err) } - - for ss := range root.all() { - ss.checkLocal(report) - } - return errors.Join(errs...) -} - -// checkStructure verifies that root and its subschemas form a tree. -// It also assigns each schema a unique path, to improve error messages. -func (root *Schema) checkStructure() error { - var check func(reflect.Value, []byte) error - check = func(v reflect.Value, path []byte) error { - // For the purpose of error messages, the root schema has path "root" - // and other schemas' paths are their JSON Pointer from the root. - p := "root" - if len(path) > 0 { - p = string(path) - } - s := v.Interface().(*Schema) - if s == nil { - return fmt.Errorf("jsonschema: schema at %s is nil", p) - } - if s.path != "" { - // We've seen s before. - // The schema graph at root is not a tree, but it needs to - // be because we assume a unique parent when we store a schema's base - // in the Schema. A cycle would also put Schema.all into an infinite - // recursion. - return fmt.Errorf("jsonschema: schemas at %s do not form a tree; %s appears more than once (also at %s)", - root, s.path, p) - } - s.path = p - - for _, info := range schemaFieldInfos { - fv := v.Elem().FieldByIndex(info.sf.Index) - switch info.sf.Type { - case schemaType: - // A field that contains an individual schema. - // A nil is valid: it just means the field isn't present. - if !fv.IsNil() { - if err := check(fv, fmt.Appendf(path, "/%s", info.jsonName)); err != nil { - return err - } - } - - case schemaSliceType: - for i := range fv.Len() { - if err := check(fv.Index(i), fmt.Appendf(path, "/%s/%d", info.jsonName, i)); err != nil { - return err - } - } - - case schemaMapType: - iter := fv.MapRange() - for iter.Next() { - key := escapeJSONPointerSegment(iter.Key().String()) - if err := check(iter.Value(), fmt.Appendf(path, "/%s/%s", info.jsonName, key)); err != nil { - return err - } - } - } - - } - return nil - } - - return check(reflect.ValueOf(root), make([]byte, 0, 256)) -} - -// checkLocal checks s for validity, independently of other schemas it may refer to. -// Since checking a regexp involves compiling it, checkLocal saves those compiled regexps -// in the schema for later use. -// It appends the errors it finds to errs. -func (s *Schema) checkLocal(report func(error)) { - addf := func(format string, args ...any) { - msg := fmt.Sprintf(format, args...) - report(fmt.Errorf("jsonschema.Schema: %s: %s", s, msg)) - } - - if s == nil { - addf("nil subschema") - return - } - if err := s.basicChecks(); err != nil { - report(err) - return - } - - // TODO: validate the schema's properties, - // ideally by jsonschema-validating it against the meta-schema. - - // Some properties are present so that Schemas can round-trip, but we do not - // validate them. - // Currently, it's just the $vocabulary property. - // As a special case, we can validate the 2020-12 meta-schema. - if s.Vocabulary != nil && s.Schema != draft202012 { - addf("cannot validate a schema with $vocabulary") - } - - // Check and compile regexps. - if s.Pattern != "" { - re, err := regexp.Compile(s.Pattern) - if err != nil { - addf("pattern: %v", err) - } else { - s.pattern = re - } - } - if len(s.PatternProperties) > 0 { - s.patternProperties = map[*regexp.Regexp]*Schema{} - for reString, subschema := range s.PatternProperties { - re, err := regexp.Compile(reString) - if err != nil { - addf("patternProperties[%q]: %v", reString, err) - continue - } - s.patternProperties[re] = subschema - } - } - - // Build a set of required properties, to avoid quadratic behavior when validating - // a struct. - if len(s.Required) > 0 { - s.isRequired = map[string]bool{} - for _, r := range s.Required { - s.isRequired[r] = true - } - } -} - -// resolveURIs resolves the ids and anchors in all the schemas of root, relative -// to baseURI. -// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2, section -// 8.2.1. - -// TODO(jba): dynamicAnchors (§8.2.2) -// -// Every schema has a base URI and a parent base URI. -// -// The parent base URI is the base URI of the lexically enclosing schema, or for -// a root schema, the URI it was loaded from or the one supplied to [Schema.Resolve]. -// -// If the schema has no $id property, the base URI of a schema is that of its parent. -// If the schema does have an $id, it must be a URI, possibly relative. The schema's -// base URI is the $id resolved (in the sense of [url.URL.ResolveReference]) against -// the parent base. -// -// As an example, consider this schema loaded from http://a.com/root.json (quotes omitted): -// -// { -// allOf: [ -// {$id: "sub1.json", minLength: 5}, -// {$id: "http://b.com", minimum: 10}, -// {not: {maximum: 20}} -// ] -// } -// -// The base URIs are as follows. Schema locations are expressed in the JSON Pointer notation. -// -// schema base URI -// root http://a.com/root.json -// allOf/0 http://a.com/sub1.json -// allOf/1 http://b.com (absolute $id; doesn't matter that it's not under the loaded URI) -// allOf/2 http://a.com/root.json (inherited from parent) -// allOf/2/not http://a.com/root.json (inherited from parent) -func resolveURIs(root *Schema, baseURI *url.URL) (map[string]*Schema, error) { - resolvedURIs := map[string]*Schema{} - - var resolve func(s, base *Schema) error - resolve = func(s, base *Schema) error { - // ids are scoped to the root. - if s.ID != "" { - // A non-empty ID establishes a new base. - idURI, err := url.Parse(s.ID) - if err != nil { - return err - } - if idURI.Fragment != "" { - return fmt.Errorf("$id %s must not have a fragment", s.ID) - } - // The base URI for this schema is its $id resolved against the parent base. - s.uri = base.uri.ResolveReference(idURI) - if !s.uri.IsAbs() { - return fmt.Errorf("$id %s does not resolve to an absolute URI (base is %s)", s.ID, s.base.uri) - } - resolvedURIs[s.uri.String()] = s - base = s // needed for anchors - } - s.base = base - - // Anchors and dynamic anchors are URI fragments that are scoped to their base. - // We treat them as keys in a map stored within the schema. - setAnchor := func(anchor string, dynamic bool) error { - if anchor != "" { - if _, ok := base.anchors[anchor]; ok { - return fmt.Errorf("duplicate anchor %q in %s", anchor, base.uri) - } - if base.anchors == nil { - base.anchors = map[string]anchorInfo{} - } - base.anchors[anchor] = anchorInfo{s, dynamic} - } - return nil - } - - setAnchor(s.Anchor, false) - setAnchor(s.DynamicAnchor, true) - - for c := range s.children() { - if err := resolve(c, base); err != nil { - return err - } - } - return nil - } - - // Set the root URI to the base for now. If the root has an $id, this will change. - root.uri = baseURI - // The original base, even if changed, is still a valid way to refer to the root. - resolvedURIs[baseURI.String()] = root - if err := resolve(root, root); err != nil { - return nil, err - } - return resolvedURIs, nil -} - -// resolveRefs replaces every ref in the schemas with the schema it refers to. -// A reference that doesn't resolve within the schema may refer to some other schema -// that needs to be loaded. -func (r *resolver) resolveRefs(rs *Resolved) error { - for s := range rs.root.all() { - if s.Ref != "" { - refSchema, _, err := r.resolveRef(rs, s, s.Ref) - if err != nil { - return err - } - // Whether or not the anchor referred to by $ref fragment is dynamic, - // the ref still treats it lexically. - s.resolvedRef = refSchema - } - if s.DynamicRef != "" { - refSchema, frag, err := r.resolveRef(rs, s, s.DynamicRef) - if err != nil { - return err - } - if frag != "" { - // The dynamic ref's fragment points to a dynamic anchor. - // We must resolve the fragment at validation time. - s.dynamicRefAnchor = frag - } else { - // There is no dynamic anchor in the lexically referenced schema, - // so the dynamic ref behaves like a lexical ref. - s.resolvedDynamicRef = refSchema - } - } - } - return nil -} - -// resolveRef resolves the reference ref, which is either s.Ref or s.DynamicRef. -func (r *resolver) resolveRef(rs *Resolved, s *Schema, ref string) (_ *Schema, dynamicFragment string, err error) { - refURI, err := url.Parse(ref) - if err != nil { - return nil, "", err - } - // URI-resolve the ref against the current base URI to get a complete URI. - refURI = s.base.uri.ResolveReference(refURI) - // The non-fragment part of a ref URI refers to the base URI of some schema. - // This part is the same for dynamic refs too: their non-fragment part resolves - // lexically. - u := *refURI - u.Fragment = "" - fraglessRefURI := &u - // Look it up locally. - referencedSchema := rs.resolvedURIs[fraglessRefURI.String()] - if referencedSchema == nil { - // The schema is remote. Maybe we've already loaded it. - // We assume that the non-fragment part of refURI refers to a top-level schema - // document. That is, we don't support the case exemplified by - // http://foo.com/bar.json/baz, where the document is in bar.json and - // the reference points to a subschema within it. - // TODO: support that case. - if lrs := r.loaded[fraglessRefURI.String()]; lrs != nil { - referencedSchema = lrs.root - } else { - // Try to load the schema. - ls, err := r.opts.Loader(fraglessRefURI) - if err != nil { - return nil, "", fmt.Errorf("loading %s: %w", fraglessRefURI, err) - } - lrs, err := r.resolve(ls, fraglessRefURI) - if err != nil { - return nil, "", err - } - referencedSchema = lrs.root - assert(referencedSchema != nil, "nil referenced schema") - } - } - - frag := refURI.Fragment - // Look up frag in refSchema. - // frag is either a JSON Pointer or the name of an anchor. - // A JSON Pointer is either the empty string or begins with a '/', - // whereas anchors are always non-empty strings that don't contain slashes. - if frag != "" && !strings.HasPrefix(frag, "/") { - info, found := referencedSchema.anchors[frag] - if !found { - return nil, "", fmt.Errorf("no anchor %q in %s", frag, s) - } - if info.dynamic { - dynamicFragment = frag - } - return info.schema, dynamicFragment, nil - } - // frag is a JSON Pointer. - s, err = dereferenceJSONPointer(referencedSchema, frag) - return s, "", err -} diff --git a/jsonschema/resolve_test.go b/jsonschema/resolve_test.go deleted file mode 100644 index 1b176bfa..00000000 --- a/jsonschema/resolve_test.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "errors" - "maps" - "net/url" - "regexp" - "slices" - "strings" - "testing" -) - -func TestSchemaStructure(t *testing.T) { - check := func(s *Schema, want string) { - t.Helper() - err := s.checkStructure() - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("checkStructure returned error %q, want %q", err, want) - } - } - - dag := &Schema{Type: "number"} - dag = &Schema{Items: dag, Contains: dag} - check(dag, "do not form a tree") - - tree := &Schema{Type: "number"} - tree.Items = tree - check(tree, "do not form a tree") - - sliceNil := &Schema{PrefixItems: []*Schema{nil}} - check(sliceNil, "is nil") - - sliceMap := &Schema{Properties: map[string]*Schema{"a": nil}} - check(sliceMap, "is nil") -} - -func TestCheckLocal(t *testing.T) { - for _, tt := range []struct { - s *Schema - want string // error must be non-nil and match this regexp - }{ - { - &Schema{Pattern: "]["}, - "regexp", - }, - { - &Schema{PatternProperties: map[string]*Schema{"*": {}}}, - "regexp", - }, - } { - _, err := tt.s.Resolve(nil) - if err == nil { - t.Errorf("%s: unexpectedly passed", tt.s.json()) - continue - } - if !regexp.MustCompile(tt.want).MatchString(err.Error()) { - t.Errorf("checkLocal returned error\n%q\nwanted it to match\n%s\nregexp: %s", - tt.s.json(), err, tt.want) - } - } -} - -func TestPaths(t *testing.T) { - // CheckStructure should assign paths to schemas. - // This test also verifies that Schema.all visits maps in sorted order. - root := &Schema{ - Type: "string", - PrefixItems: []*Schema{{Type: "int"}, {Items: &Schema{Type: "null"}}}, - Contains: &Schema{Properties: map[string]*Schema{ - "~1": {Type: "boolean"}, - "p": {}, - }}, - } - - type item struct { - s *Schema - p string - } - want := []item{ - {root, "root"}, - {root.Contains, "/contains"}, - {root.Contains.Properties["p"], "/contains/properties/p"}, - {root.Contains.Properties["~1"], "/contains/properties/~01"}, - {root.PrefixItems[0], "/prefixItems/0"}, - {root.PrefixItems[1], "/prefixItems/1"}, - {root.PrefixItems[1].Items, "/prefixItems/1/items"}, - } - if err := root.checkStructure(); err != nil { - t.Fatal(err) - } - - var got []item - for s := range root.all() { - got = append(got, item{s, s.path}) - } - if !slices.Equal(got, want) { - t.Errorf("\ngot %v\nwant %v", got, want) - } -} - -func TestResolveURIs(t *testing.T) { - for _, baseURI := range []string{"", "http://a.com"} { - t.Run(baseURI, func(t *testing.T) { - root := &Schema{ - ID: "http://b.com", - Items: &Schema{ - ID: "/foo.json", - }, - Contains: &Schema{ - ID: "/bar.json", - Anchor: "a", - DynamicAnchor: "da", - Items: &Schema{ - Anchor: "b", - Items: &Schema{ - // An ID shouldn't be a query param, but this tests - // resolving an ID with its parent. - ID: "?items", - Anchor: "c", - }, - }, - }, - } - base, err := url.Parse(baseURI) - if err != nil { - t.Fatal(err) - } - got, err := resolveURIs(root, base) - if err != nil { - t.Fatal(err) - } - - wantIDs := map[string]*Schema{ - baseURI: root, - "http://b.com/foo.json": root.Items, - "http://b.com/bar.json": root.Contains, - "http://b.com/bar.json?items": root.Contains.Items.Items, - } - if baseURI != root.ID { - wantIDs[root.ID] = root - } - wantAnchors := map[*Schema]map[string]anchorInfo{ - root.Contains: { - "a": anchorInfo{root.Contains, false}, - "da": anchorInfo{root.Contains, true}, - "b": anchorInfo{root.Contains.Items, false}, - }, - root.Contains.Items.Items: { - "c": anchorInfo{root.Contains.Items.Items, false}, - }, - } - - gotKeys := slices.Sorted(maps.Keys(got)) - wantKeys := slices.Sorted(maps.Keys(wantIDs)) - if !slices.Equal(gotKeys, wantKeys) { - t.Errorf("ID keys:\ngot %q\nwant %q", gotKeys, wantKeys) - } - if !maps.Equal(got, wantIDs) { - t.Errorf("IDs:\ngot %+v\n\nwant %+v", got, wantIDs) - } - for s := range root.all() { - if want := wantAnchors[s]; want != nil { - if got := s.anchors; !maps.Equal(got, want) { - t.Errorf("anchors:\ngot %+v\n\nwant %+v", got, want) - } - } else if s.anchors != nil { - t.Errorf("non-nil anchors for %s", s) - } - } - }) - } -} - -func TestRefCycle(t *testing.T) { - // Verify that cycles of refs are OK. - // The test suite doesn't check this, surprisingly. - schemas := map[string]*Schema{ - "root": {Ref: "a"}, - "a": {Ref: "b"}, - "b": {Ref: "a"}, - } - - loader := func(uri *url.URL) (*Schema, error) { - s, ok := schemas[uri.Path[1:]] - if !ok { - return nil, errors.New("not found") - } - return s, nil - } - - rs, err := schemas["root"].Resolve(&ResolveOptions{Loader: loader}) - if err != nil { - t.Fatal(err) - } - - check := func(s *Schema, key string) { - t.Helper() - if s.resolvedRef != schemas[key] { - t.Errorf("%s resolvedRef != schemas[%q]", s.json(), key) - } - } - - check(rs.root, "a") - check(schemas["a"], "b") - check(schemas["b"], "a") -} diff --git a/jsonschema/schema.go b/jsonschema/schema.go deleted file mode 100644 index 26623f1b..00000000 --- a/jsonschema/schema.go +++ /dev/null @@ -1,445 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "bytes" - "cmp" - "encoding/json" - "errors" - "fmt" - "iter" - "maps" - "math" - "net/url" - "reflect" - "regexp" - "slices" - - "github.com/modelcontextprotocol/go-sdk/internal/util" -) - -// A Schema is a JSON schema object. -// It corresponds to the 2020-12 draft, as described in https://json-schema.org/draft/2020-12, -// specifically: -// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-01 -// - https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01 -// -// A Schema value may have non-zero values for more than one field: -// all relevant non-zero fields are used for validation. -// There is one exception to provide more Go type-safety: the Type and Types fields -// are mutually exclusive. -// -// Since this struct is a Go representation of a JSON value, it inherits JSON's -// distinction between nil and empty. Nil slices and maps are considered absent, -// but empty ones are present and affect validation. For example, -// -// Schema{Enum: nil} -// -// is equivalent to an empty schema, so it validates every instance. But -// -// Schema{Enum: []any{}} -// -// requires equality to some slice element, so it vacuously rejects every instance. -type Schema struct { - // core - ID string `json:"$id,omitempty"` - Schema string `json:"$schema,omitempty"` - Ref string `json:"$ref,omitempty"` - Comment string `json:"$comment,omitempty"` - Defs map[string]*Schema `json:"$defs,omitempty"` - // definitions is deprecated but still allowed. It is a synonym for $defs. - Definitions map[string]*Schema `json:"definitions,omitempty"` - - Anchor string `json:"$anchor,omitempty"` - DynamicAnchor string `json:"$dynamicAnchor,omitempty"` - DynamicRef string `json:"$dynamicRef,omitempty"` - Vocabulary map[string]bool `json:"$vocabulary,omitempty"` - - // metadata - Title string `json:"title,omitempty"` - Description string `json:"description,omitempty"` - Default json.RawMessage `json:"default,omitempty"` - Deprecated bool `json:"deprecated,omitempty"` - ReadOnly bool `json:"readOnly,omitempty"` - WriteOnly bool `json:"writeOnly,omitempty"` - Examples []any `json:"examples,omitempty"` - - // validation - // Use Type for a single type, or Types for multiple types; never both. - Type string `json:"-"` - Types []string `json:"-"` - Enum []any `json:"enum,omitempty"` - // Const is *any because a JSON null (Go nil) is a valid value. - Const *any `json:"const,omitempty"` - MultipleOf *float64 `json:"multipleOf,omitempty"` - Minimum *float64 `json:"minimum,omitempty"` - Maximum *float64 `json:"maximum,omitempty"` - ExclusiveMinimum *float64 `json:"exclusiveMinimum,omitempty"` - ExclusiveMaximum *float64 `json:"exclusiveMaximum,omitempty"` - MinLength *int `json:"minLength,omitempty"` - MaxLength *int `json:"maxLength,omitempty"` - Pattern string `json:"pattern,omitempty"` - - // arrays - PrefixItems []*Schema `json:"prefixItems,omitempty"` - Items *Schema `json:"items,omitempty"` - MinItems *int `json:"minItems,omitempty"` - MaxItems *int `json:"maxItems,omitempty"` - AdditionalItems *Schema `json:"additionalItems,omitempty"` - UniqueItems bool `json:"uniqueItems,omitempty"` - Contains *Schema `json:"contains,omitempty"` - MinContains *int `json:"minContains,omitempty"` // *int, not int: default is 1, not 0 - MaxContains *int `json:"maxContains,omitempty"` - UnevaluatedItems *Schema `json:"unevaluatedItems,omitempty"` - - // objects - MinProperties *int `json:"minProperties,omitempty"` - MaxProperties *int `json:"maxProperties,omitempty"` - Required []string `json:"required,omitempty"` - DependentRequired map[string][]string `json:"dependentRequired,omitempty"` - Properties map[string]*Schema `json:"properties,omitempty"` - PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` - AdditionalProperties *Schema `json:"additionalProperties,omitempty"` - PropertyNames *Schema `json:"propertyNames,omitempty"` - UnevaluatedProperties *Schema `json:"unevaluatedProperties,omitempty"` - - // logic - AllOf []*Schema `json:"allOf,omitempty"` - AnyOf []*Schema `json:"anyOf,omitempty"` - OneOf []*Schema `json:"oneOf,omitempty"` - Not *Schema `json:"not,omitempty"` - - // conditional - If *Schema `json:"if,omitempty"` - Then *Schema `json:"then,omitempty"` - Else *Schema `json:"else,omitempty"` - DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` - - // other - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.8 - ContentEncoding string `json:"contentEncoding,omitempty"` - ContentMediaType string `json:"contentMediaType,omitempty"` - ContentSchema *Schema `json:"contentSchema,omitempty"` - - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7 - Format string `json:"format,omitempty"` - - // Extra allows for additional keywords beyond those specified. - Extra map[string]any `json:"-"` - - // computed fields - - // This schema's base schema. - // If the schema is the root or has an ID, its base is itself. - // Otherwise, its base is the innermost enclosing schema whose base - // is itself. - // Intuitively, a base schema is one that can be referred to with a - // fragmentless URI. - base *Schema - - // The URI for the schema, if it is the root or has an ID. - // Otherwise nil. - // Invariants: - // s.base.uri != nil. - // s.base == s <=> s.uri != nil - uri *url.URL - - // The JSON Pointer path from the root schema to here. - // Used in errors. - path string - - // The schema to which Ref refers. - resolvedRef *Schema - - // If the schema has a dynamic ref, exactly one of the next two fields - // will be non-zero after successful resolution. - // The schema to which the dynamic ref refers when it acts lexically. - resolvedDynamicRef *Schema - // The anchor to look up on the stack when the dynamic ref acts dynamically. - dynamicRefAnchor string - - // Map from anchors to subschemas. - anchors map[string]anchorInfo - - // compiled regexps - pattern *regexp.Regexp - patternProperties map[*regexp.Regexp]*Schema - - // the set of required properties - isRequired map[string]bool -} - -// falseSchema returns a new Schema tree that fails to validate any value. -func falseSchema() *Schema { - return &Schema{Not: &Schema{}} -} - -// anchorInfo records the subschema to which an anchor refers, and whether -// the anchor keyword is $anchor or $dynamicAnchor. -type anchorInfo struct { - schema *Schema - dynamic bool -} - -// String returns a short description of the schema. -func (s *Schema) String() string { - if s.uri != nil { - if u := s.uri.String(); u != "" { - return u - } - } - if a := cmp.Or(s.Anchor, s.DynamicAnchor); a != "" { - return fmt.Sprintf("%q, anchor %s", s.base.uri.String(), a) - } - if s.path != "" { - return s.path - } - return "" -} - -// ResolvedRef returns the Schema to which this schema's $ref keyword -// refers, or nil if it doesn't have a $ref. -// It returns nil if this schema has not been resolved, meaning that -// [Schema.Resolve] was called on it or one of its ancestors. -func (s *Schema) ResolvedRef() *Schema { - return s.resolvedRef -} - -func (s *Schema) basicChecks() error { - if s.Type != "" && s.Types != nil { - return errors.New("both Type and Types are set; at most one should be") - } - if s.Defs != nil && s.Definitions != nil { - return errors.New("both Defs and Definitions are set; at most one should be") - } - return nil -} - -type schemaWithoutMethods Schema // doesn't implement json.{Unm,M}arshaler - -func (s *Schema) MarshalJSON() ([]byte, error) { - if err := s.basicChecks(); err != nil { - return nil, err - } - // Marshal either Type or Types as "type". - var typ any - switch { - case s.Type != "": - typ = s.Type - case s.Types != nil: - typ = s.Types - } - ms := struct { - Type any `json:"type,omitempty"` - *schemaWithoutMethods - }{ - Type: typ, - schemaWithoutMethods: (*schemaWithoutMethods)(s), - } - return marshalStructWithMap(&ms, "Extra") -} - -func (s *Schema) UnmarshalJSON(data []byte) error { - // A JSON boolean is a valid schema. - var b bool - if err := json.Unmarshal(data, &b); err == nil { - if b { - // true is the empty schema, which validates everything. - *s = Schema{} - } else { - // false is the schema that validates nothing. - *s = *falseSchema() - } - return nil - } - - ms := struct { - Type json.RawMessage `json:"type,omitempty"` - Const json.RawMessage `json:"const,omitempty"` - MinLength *integer `json:"minLength,omitempty"` - MaxLength *integer `json:"maxLength,omitempty"` - MinItems *integer `json:"minItems,omitempty"` - MaxItems *integer `json:"maxItems,omitempty"` - MinProperties *integer `json:"minProperties,omitempty"` - MaxProperties *integer `json:"maxProperties,omitempty"` - MinContains *integer `json:"minContains,omitempty"` - MaxContains *integer `json:"maxContains,omitempty"` - - *schemaWithoutMethods - }{ - schemaWithoutMethods: (*schemaWithoutMethods)(s), - } - if err := unmarshalStructWithMap(data, &ms, "Extra"); err != nil { - return err - } - // Unmarshal "type" as either Type or Types. - var err error - if len(ms.Type) > 0 { - switch ms.Type[0] { - case '"': - err = json.Unmarshal(ms.Type, &s.Type) - case '[': - err = json.Unmarshal(ms.Type, &s.Types) - default: - err = fmt.Errorf(`invalid value for "type": %q`, ms.Type) - } - } - if err != nil { - return err - } - - unmarshalAnyPtr := func(p **any, raw json.RawMessage) error { - if len(raw) == 0 { - return nil - } - if bytes.Equal(raw, []byte("null")) { - *p = new(any) - return nil - } - return json.Unmarshal(raw, p) - } - - // Setting Const to a pointer to null will marshal properly, but won't - // unmarshal: the *any is set to nil, not a pointer to nil. - if err := unmarshalAnyPtr(&s.Const, ms.Const); err != nil { - return err - } - - set := func(dst **int, src *integer) { - if src != nil { - *dst = Ptr(int(*src)) - } - } - - set(&s.MinLength, ms.MinLength) - set(&s.MaxLength, ms.MaxLength) - set(&s.MinItems, ms.MinItems) - set(&s.MaxItems, ms.MaxItems) - set(&s.MinProperties, ms.MinProperties) - set(&s.MaxProperties, ms.MaxProperties) - set(&s.MinContains, ms.MinContains) - set(&s.MaxContains, ms.MaxContains) - - return nil -} - -type integer int32 // for the integer-valued fields of Schema - -func (ip *integer) UnmarshalJSON(data []byte) error { - if len(data) == 0 { - // nothing to do - return nil - } - // If there is a decimal point, src is a floating-point number. - var i int64 - if bytes.ContainsRune(data, '.') { - var f float64 - if err := json.Unmarshal(data, &f); err != nil { - return errors.New("not a number") - } - i = int64(f) - if float64(i) != f { - return errors.New("not an integer value") - } - } else { - if err := json.Unmarshal(data, &i); err != nil { - return errors.New("cannot be unmarshaled into an int") - } - } - // Ensure behavior is the same on both 32-bit and 64-bit systems. - if i < math.MinInt32 || i > math.MaxInt32 { - return errors.New("integer is out of range") - } - *ip = integer(i) - return nil -} - -// Ptr returns a pointer to a new variable whose value is x. -func Ptr[T any](x T) *T { return &x } - -// every applies f preorder to every schema under s including s. -// The second argument to f is the path to the schema appended to the argument path. -// It stops when f returns false. -func (s *Schema) every(f func(*Schema) bool) bool { - return f(s) && s.everyChild(func(s *Schema) bool { return s.every(f) }) -} - -// everyChild reports whether f is true for every immediate child schema of s. -func (s *Schema) everyChild(f func(*Schema) bool) bool { - v := reflect.ValueOf(s) - for _, info := range schemaFieldInfos { - fv := v.Elem().FieldByIndex(info.sf.Index) - switch info.sf.Type { - case schemaType: - // A field that contains an individual schema. A nil is valid: it just means the field isn't present. - c := fv.Interface().(*Schema) - if c != nil && !f(c) { - return false - } - - case schemaSliceType: - slice := fv.Interface().([]*Schema) - for _, c := range slice { - if !f(c) { - return false - } - } - - case schemaMapType: - // Sort keys for determinism. - m := fv.Interface().(map[string]*Schema) - for _, k := range slices.Sorted(maps.Keys(m)) { - if !f(m[k]) { - return false - } - } - } - } - return true -} - -// all wraps every in an iterator. -func (s *Schema) all() iter.Seq[*Schema] { - return func(yield func(*Schema) bool) { s.every(yield) } -} - -// children wraps everyChild in an iterator. -func (s *Schema) children() iter.Seq[*Schema] { - return func(yield func(*Schema) bool) { s.everyChild(yield) } -} - -var ( - schemaType = reflect.TypeFor[*Schema]() - schemaSliceType = reflect.TypeFor[[]*Schema]() - schemaMapType = reflect.TypeFor[map[string]*Schema]() -) - -type structFieldInfo struct { - sf reflect.StructField - jsonName string -} - -var ( - // the visible fields of Schema that have a JSON name, sorted by that name - schemaFieldInfos []structFieldInfo - // map from JSON name to field - schemaFieldMap = map[string]reflect.StructField{} -) - -func init() { - for _, sf := range reflect.VisibleFields(reflect.TypeFor[Schema]()) { - info := util.FieldJSONInfo(sf) - if !info.Omit { - schemaFieldInfos = append(schemaFieldInfos, structFieldInfo{sf, info.Name}) - } - } - slices.SortFunc(schemaFieldInfos, func(i1, i2 structFieldInfo) int { - return cmp.Compare(i1.jsonName, i2.jsonName) - }) - for _, info := range schemaFieldInfos { - schemaFieldMap[info.jsonName] = info.sf - } -} diff --git a/jsonschema/schema_test.go b/jsonschema/schema_test.go deleted file mode 100644 index 4ceb1ee1..00000000 --- a/jsonschema/schema_test.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "fmt" - "math" - "regexp" - "testing" -) - -func TestGoRoundTrip(t *testing.T) { - // Verify that Go representations round-trip. - for _, s := range []*Schema{ - {Type: "null"}, - {Types: []string{"null", "number"}}, - {Type: "string", MinLength: Ptr(20)}, - {Minimum: Ptr(20.0)}, - {Items: &Schema{Type: "integer"}}, - {Const: Ptr(any(0))}, - {Const: Ptr(any(nil))}, - {Const: Ptr(any([]int{}))}, - {Const: Ptr(any(map[string]any{}))}, - {Default: mustMarshal(1)}, - {Default: mustMarshal(nil)}, - {Extra: map[string]any{"test": "value"}}, - } { - data, err := json.Marshal(s) - if err != nil { - t.Fatal(err) - } - var got *Schema - mustUnmarshal(t, data, &got) - if !Equal(got, s) { - t.Errorf("got %s, want %s", got.json(), s.json()) - if got.Const != nil && s.Const != nil { - t.Logf("Consts: got %#v (%[1]T), want %#v (%[2]T)", *got.Const, *s.Const) - } - } - } -} - -func TestJSONRoundTrip(t *testing.T) { - // Verify that JSON texts for schemas marshal into equivalent forms. - // We don't expect everything to round-trip perfectly. For example, "true" and "false" - // will turn into their object equivalents. - // But most things should. - // Some of these cases test Schema.{UnM,M}arshalJSON. - // Most of others follow from the behavior of encoding/json, but they are still - // valuable as regression tests of this package's behavior. - for _, tt := range []struct { - in, want string - }{ - {`true`, `{}`}, // boolean schemas become object schemas - {`false`, `{"not":{}}`}, - {`{"type":"", "enum":null}`, `{}`}, // empty fields are omitted - {`{"minimum":1}`, `{"minimum":1}`}, - {`{"minimum":1.0}`, `{"minimum":1}`}, // floating-point integers lose their fractional part - {`{"minLength":1.0}`, `{"minLength":1}`}, // some floats are unmarshaled into ints, but you can't tell - { - // map keys are sorted - `{"$vocabulary":{"b":true, "a":false}}`, - `{"$vocabulary":{"a":false,"b":true}}`, - }, - {`{"unk":0}`, `{"unk":0}`}, // unknown fields are not dropped - { - // known and unknown fields are not dropped - // note that the order will be by the declaration order in the anonymous struct inside MarshalJSON - `{"comment":"test","type":"example","unk":0}`, - `{"type":"example","comment":"test","unk":0}`, - }, - {`{"extra":0}`, `{"extra":0}`}, // extra is not a special keyword and should not be dropped - {`{"Extra":0}`, `{"Extra":0}`}, // Extra is not a special keyword and should not be dropped - } { - var s Schema - mustUnmarshal(t, []byte(tt.in), &s) - data, err := json.Marshal(&s) - if err != nil { - t.Fatal(err) - } - if got := string(data); got != tt.want { - t.Errorf("%s:\ngot %s\nwant %s", tt.in, got, tt.want) - } - } -} - -func TestUnmarshalErrors(t *testing.T) { - for _, tt := range []struct { - in string - want string // error must match this regexp - }{ - {`1`, "cannot unmarshal number"}, - {`{"type":1}`, `invalid value for "type"`}, - {`{"minLength":1.5}`, `not an integer value`}, - {`{"maxLength":1.5}`, `not an integer value`}, - {`{"minItems":1.5}`, `not an integer value`}, - {`{"maxItems":1.5}`, `not an integer value`}, - {`{"minProperties":1.5}`, `not an integer value`}, - {`{"maxProperties":1.5}`, `not an integer value`}, - {`{"minContains":1.5}`, `not an integer value`}, - {`{"maxContains":1.5}`, `not an integer value`}, - {fmt.Sprintf(`{"maxContains":%d}`, int64(math.MaxInt32+1)), `out of range`}, - {`{"minLength":9e99}`, `cannot be unmarshaled`}, - {`{"minLength":"1.5"}`, `not a number`}, - } { - var s Schema - err := json.Unmarshal([]byte(tt.in), &s) - if err == nil { - t.Fatalf("%s: no error but expected one", tt.in) - } - if !regexp.MustCompile(tt.want).MatchString(err.Error()) { - t.Errorf("%s: error %q does not match %q", tt.in, err, tt.want) - } - - } -} - -func mustUnmarshal(t *testing.T, data []byte, ptr any) { - t.Helper() - if err := json.Unmarshal(data, ptr); err != nil { - t.Fatal(err) - } -} - -// json returns the schema in json format. -func (s *Schema) json() string { - data, err := json.Marshal(s) - if err != nil { - return fmt.Sprintf("", err) - } - return string(data) -} - -// json returns the schema in json format, indented. -func (s *Schema) jsonIndent() string { - data, err := json.MarshalIndent(s, "", " ") - if err != nil { - return fmt.Sprintf("", err) - } - return string(data) -} diff --git a/jsonschema/testdata/draft2020-12/README.md b/jsonschema/testdata/draft2020-12/README.md deleted file mode 100644 index 09ae5704..00000000 --- a/jsonschema/testdata/draft2020-12/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# JSON Schema test suite for 2020-12 - -These files were copied from -https://github.com/json-schema-org/JSON-Schema-Test-Suite/tree/83e866b46c9f9e7082fd51e83a61c5f2145a1ab7/tests/draft2020-12. diff --git a/jsonschema/testdata/draft2020-12/additionalProperties.json b/jsonschema/testdata/draft2020-12/additionalProperties.json deleted file mode 100644 index 9618575e..00000000 --- a/jsonschema/testdata/draft2020-12/additionalProperties.json +++ /dev/null @@ -1,219 +0,0 @@ -[ - { - "description": - "additionalProperties being false does not allow other properties", - "specification": [ { "core":"10.3.2.3", "quote": "The value of \"additionalProperties\" MUST be a valid JSON Schema. Boolean \"false\" forbids everything." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo": {}, "bar": {}}, - "patternProperties": { "^v": {} }, - "additionalProperties": false - }, - "tests": [ - { - "description": "no additional properties is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "an additional property is invalid", - "data": {"foo" : 1, "bar" : 2, "quux" : "boom"}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [1, 2, 3], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobarbaz", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - }, - { - "description": "patternProperties are not additional properties", - "data": {"foo":1, "vroom": 2}, - "valid": true - } - ] - }, - { - "description": "non-ASCII pattern with additionalProperties", - "specification": [ { "core":"10.3.2.3"} ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": {"^á": {}}, - "additionalProperties": false - }, - "tests": [ - { - "description": "matching the pattern is valid", - "data": {"ármányos": 2}, - "valid": true - }, - { - "description": "not matching the pattern is invalid", - "data": {"élmény": 2}, - "valid": false - } - ] - }, - { - "description": "additionalProperties with schema", - "specification": [ { "core":"10.3.2.3", "quote": "The value of \"additionalProperties\" MUST be a valid JSON Schema." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo": {}, "bar": {}}, - "additionalProperties": {"type": "boolean"} - }, - "tests": [ - { - "description": "no additional properties is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "an additional valid property is valid", - "data": {"foo" : 1, "bar" : 2, "quux" : true}, - "valid": true - }, - { - "description": "an additional invalid property is invalid", - "data": {"foo" : 1, "bar" : 2, "quux" : 12}, - "valid": false - } - ] - }, - { - "description": "additionalProperties can exist by itself", - "specification": [ { "core":"10.3.2.3", "quote": "With no other applicator applying to object instances. This validates all the instance values irrespective of their property names" } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "additionalProperties": {"type": "boolean"} - }, - "tests": [ - { - "description": "an additional valid property is valid", - "data": {"foo" : true}, - "valid": true - }, - { - "description": "an additional invalid property is invalid", - "data": {"foo" : 1}, - "valid": false - } - ] - }, - { - "description": "additionalProperties are allowed by default", - "specification": [ { "core":"10.3.2.3", "quote": "Omitting this keyword has the same assertion behavior as an empty schema." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo": {}, "bar": {}} - }, - "tests": [ - { - "description": "additional properties are allowed", - "data": {"foo": 1, "bar": 2, "quux": true}, - "valid": true - } - ] - }, - { - "description": "additionalProperties does not look in applicators", - "specification":[ { "core": "10.2", "quote": "Subschemas of applicator keywords evaluate the instance completely independently such that the results of one such subschema MUST NOT impact the results of sibling subschemas." } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {"properties": {"foo": {}}} - ], - "additionalProperties": {"type": "boolean"} - }, - "tests": [ - { - "description": "properties defined in allOf are not examined", - "data": {"foo": 1, "bar": true}, - "valid": false - } - ] - }, - { - "description": "additionalProperties with null valued instance properties", - "specification": [ { "core":"10.3.2.3" } ], - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "additionalProperties": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null values", - "data": {"foo": null}, - "valid": true - } - ] - }, - { - "description": "additionalProperties with propertyNames", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": { - "maxLength": 5 - }, - "additionalProperties": { - "type": "number" - } - }, - "tests": [ - { - "description": "Valid against both keywords", - "data": { "apple": 4 }, - "valid": true - }, - { - "description": "Valid against propertyNames, but not additionalProperties", - "data": { "fig": 2, "pear": "available" }, - "valid": false - } - ] - }, - { - "description": "dependentSchemas with additionalProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo2": {}}, - "dependentSchemas": { - "foo" : {}, - "foo2": { - "properties": { - "bar": {} - } - } - }, - "additionalProperties": false - }, - "tests": [ - { - "description": "additionalProperties doesn't consider dependentSchemas", - "data": {"foo": ""}, - "valid": false - }, - { - "description": "additionalProperties can't see bar", - "data": {"bar": ""}, - "valid": false - }, - { - "description": "additionalProperties can't see bar even when foo2 is present", - "data": {"foo2": "", "bar": ""}, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/allOf.json b/jsonschema/testdata/draft2020-12/allOf.json deleted file mode 100644 index 9e87903f..00000000 --- a/jsonschema/testdata/draft2020-12/allOf.json +++ /dev/null @@ -1,312 +0,0 @@ -[ - { - "description": "allOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "properties": { - "bar": {"type": "integer"} - }, - "required": ["bar"] - }, - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "allOf", - "data": {"foo": "baz", "bar": 2}, - "valid": true - }, - { - "description": "mismatch second", - "data": {"foo": "baz"}, - "valid": false - }, - { - "description": "mismatch first", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "wrong type", - "data": {"foo": "baz", "bar": "quux"}, - "valid": false - } - ] - }, - { - "description": "allOf with base schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"bar": {"type": "integer"}}, - "required": ["bar"], - "allOf" : [ - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - }, - { - "properties": { - "baz": {"type": "null"} - }, - "required": ["baz"] - } - ] - }, - "tests": [ - { - "description": "valid", - "data": {"foo": "quux", "bar": 2, "baz": null}, - "valid": true - }, - { - "description": "mismatch base schema", - "data": {"foo": "quux", "baz": null}, - "valid": false - }, - { - "description": "mismatch first allOf", - "data": {"bar": 2, "baz": null}, - "valid": false - }, - { - "description": "mismatch second allOf", - "data": {"foo": "quux", "bar": 2}, - "valid": false - }, - { - "description": "mismatch both", - "data": {"bar": 2}, - "valid": false - } - ] - }, - { - "description": "allOf simple types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {"maximum": 30}, - {"minimum": 20} - ] - }, - "tests": [ - { - "description": "valid", - "data": 25, - "valid": true - }, - { - "description": "mismatch one", - "data": 35, - "valid": false - } - ] - }, - { - "description": "allOf with boolean schemas, all true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [true, true] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "allOf with boolean schemas, some false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [true, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "allOf with boolean schemas, all false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [false, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "allOf with one empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {} - ] - }, - "tests": [ - { - "description": "any data is valid", - "data": 1, - "valid": true - } - ] - }, - { - "description": "allOf with two empty schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {}, - {} - ] - }, - "tests": [ - { - "description": "any data is valid", - "data": 1, - "valid": true - } - ] - }, - { - "description": "allOf with the first empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - {}, - { "type": "number" } - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "allOf with the last empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { "type": "number" }, - {} - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "nested allOf, to check validation semantics", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "allOf": [ - { - "type": "null" - } - ] - } - ] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "anything non-null is invalid", - "data": 123, - "valid": false - } - ] - }, - { - "description": "allOf combined with anyOf, oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ { "multipleOf": 2 } ], - "anyOf": [ { "multipleOf": 3 } ], - "oneOf": [ { "multipleOf": 5 } ] - }, - "tests": [ - { - "description": "allOf: false, anyOf: false, oneOf: false", - "data": 1, - "valid": false - }, - { - "description": "allOf: false, anyOf: false, oneOf: true", - "data": 5, - "valid": false - }, - { - "description": "allOf: false, anyOf: true, oneOf: false", - "data": 3, - "valid": false - }, - { - "description": "allOf: false, anyOf: true, oneOf: true", - "data": 15, - "valid": false - }, - { - "description": "allOf: true, anyOf: false, oneOf: false", - "data": 2, - "valid": false - }, - { - "description": "allOf: true, anyOf: false, oneOf: true", - "data": 10, - "valid": false - }, - { - "description": "allOf: true, anyOf: true, oneOf: false", - "data": 6, - "valid": false - }, - { - "description": "allOf: true, anyOf: true, oneOf: true", - "data": 30, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/anchor.json b/jsonschema/testdata/draft2020-12/anchor.json deleted file mode 100644 index 99143fa1..00000000 --- a/jsonschema/testdata/draft2020-12/anchor.json +++ /dev/null @@ -1,120 +0,0 @@ -[ - { - "description": "Location-independent identifier", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#foo", - "$defs": { - "A": { - "$anchor": "foo", - "type": "integer" - } - } - }, - "tests": [ - { - "data": 1, - "description": "match", - "valid": true - }, - { - "data": "a", - "description": "mismatch", - "valid": false - } - ] - }, - { - "description": "Location-independent identifier with absolute URI", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/bar#foo", - "$defs": { - "A": { - "$id": "http://localhost:1234/draft2020-12/bar", - "$anchor": "foo", - "type": "integer" - } - } - }, - "tests": [ - { - "data": 1, - "description": "match", - "valid": true - }, - { - "data": "a", - "description": "mismatch", - "valid": false - } - ] - }, - { - "description": "Location-independent identifier with base URI change in subschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/root", - "$ref": "http://localhost:1234/draft2020-12/nested.json#foo", - "$defs": { - "A": { - "$id": "nested.json", - "$defs": { - "B": { - "$anchor": "foo", - "type": "integer" - } - } - } - } - }, - "tests": [ - { - "data": 1, - "description": "match", - "valid": true - }, - { - "data": "a", - "description": "mismatch", - "valid": false - } - ] - }, - { - "description": "same $anchor with different base uri", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/foobar", - "$defs": { - "A": { - "$id": "child1", - "allOf": [ - { - "$id": "child2", - "$anchor": "my_anchor", - "type": "number" - }, - { - "$anchor": "my_anchor", - "type": "string" - } - ] - } - }, - "$ref": "child1#my_anchor" - }, - "tests": [ - { - "description": "$ref resolves to /$defs/A/allOf/1", - "data": "a", - "valid": true - }, - { - "description": "$ref does not resolve to /$defs/A/allOf/0", - "data": 1, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/anyOf.json b/jsonschema/testdata/draft2020-12/anyOf.json deleted file mode 100644 index 89b192db..00000000 --- a/jsonschema/testdata/draft2020-12/anyOf.json +++ /dev/null @@ -1,203 +0,0 @@ -[ - { - "description": "anyOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { - "type": "integer" - }, - { - "minimum": 2 - } - ] - }, - "tests": [ - { - "description": "first anyOf valid", - "data": 1, - "valid": true - }, - { - "description": "second anyOf valid", - "data": 2.5, - "valid": true - }, - { - "description": "both anyOf valid", - "data": 3, - "valid": true - }, - { - "description": "neither anyOf valid", - "data": 1.5, - "valid": false - } - ] - }, - { - "description": "anyOf with base schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string", - "anyOf" : [ - { - "maxLength": 2 - }, - { - "minLength": 4 - } - ] - }, - "tests": [ - { - "description": "mismatch base schema", - "data": 3, - "valid": false - }, - { - "description": "one anyOf valid", - "data": "foobar", - "valid": true - }, - { - "description": "both anyOf invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "anyOf with boolean schemas, all true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [true, true] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "anyOf with boolean schemas, some true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [true, false] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "anyOf with boolean schemas, all false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [false, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "anyOf complex types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { - "properties": { - "bar": {"type": "integer"} - }, - "required": ["bar"] - }, - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "first anyOf valid (complex)", - "data": {"bar": 2}, - "valid": true - }, - { - "description": "second anyOf valid (complex)", - "data": {"foo": "baz"}, - "valid": true - }, - { - "description": "both anyOf valid (complex)", - "data": {"foo": "baz", "bar": 2}, - "valid": true - }, - { - "description": "neither anyOf valid (complex)", - "data": {"foo": 2, "bar": "quux"}, - "valid": false - } - ] - }, - { - "description": "anyOf with one empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { "type": "number" }, - {} - ] - }, - "tests": [ - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "number is valid", - "data": 123, - "valid": true - } - ] - }, - { - "description": "nested anyOf, to check validation semantics", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "anyOf": [ - { - "anyOf": [ - { - "type": "null" - } - ] - } - ] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "anything non-null is invalid", - "data": 123, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/boolean_schema.json b/jsonschema/testdata/draft2020-12/boolean_schema.json deleted file mode 100644 index 6d40f23f..00000000 --- a/jsonschema/testdata/draft2020-12/boolean_schema.json +++ /dev/null @@ -1,104 +0,0 @@ -[ - { - "description": "boolean schema 'true'", - "schema": true, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "boolean true is valid", - "data": true, - "valid": true - }, - { - "description": "boolean false is valid", - "data": false, - "valid": true - }, - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - }, - { - "description": "array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "boolean schema 'false'", - "schema": false, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "boolean true is invalid", - "data": true, - "valid": false - }, - { - "description": "boolean false is invalid", - "data": false, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - }, - { - "description": "object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/const.json b/jsonschema/testdata/draft2020-12/const.json deleted file mode 100644 index 50be86a0..00000000 --- a/jsonschema/testdata/draft2020-12/const.json +++ /dev/null @@ -1,387 +0,0 @@ -[ - { - "description": "const validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 2 - }, - "tests": [ - { - "description": "same value is valid", - "data": 2, - "valid": true - }, - { - "description": "another value is invalid", - "data": 5, - "valid": false - }, - { - "description": "another type is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "const with object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": {"foo": "bar", "baz": "bax"} - }, - "tests": [ - { - "description": "same object is valid", - "data": {"foo": "bar", "baz": "bax"}, - "valid": true - }, - { - "description": "same object with different property order is valid", - "data": {"baz": "bax", "foo": "bar"}, - "valid": true - }, - { - "description": "another object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "another type is invalid", - "data": [1, 2], - "valid": false - } - ] - }, - { - "description": "const with array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": [{ "foo": "bar" }] - }, - "tests": [ - { - "description": "same array is valid", - "data": [{"foo": "bar"}], - "valid": true - }, - { - "description": "another array item is invalid", - "data": [2], - "valid": false - }, - { - "description": "array with additional items is invalid", - "data": [1, 2, 3], - "valid": false - } - ] - }, - { - "description": "const with null", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": null - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "not null is invalid", - "data": 0, - "valid": false - } - ] - }, - { - "description": "const with false does not match 0", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": false - }, - "tests": [ - { - "description": "false is valid", - "data": false, - "valid": true - }, - { - "description": "integer zero is invalid", - "data": 0, - "valid": false - }, - { - "description": "float zero is invalid", - "data": 0.0, - "valid": false - } - ] - }, - { - "description": "const with true does not match 1", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": true - }, - "tests": [ - { - "description": "true is valid", - "data": true, - "valid": true - }, - { - "description": "integer one is invalid", - "data": 1, - "valid": false - }, - { - "description": "float one is invalid", - "data": 1.0, - "valid": false - } - ] - }, - { - "description": "const with [false] does not match [0]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": [false] - }, - "tests": [ - { - "description": "[false] is valid", - "data": [false], - "valid": true - }, - { - "description": "[0] is invalid", - "data": [0], - "valid": false - }, - { - "description": "[0.0] is invalid", - "data": [0.0], - "valid": false - } - ] - }, - { - "description": "const with [true] does not match [1]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": [true] - }, - "tests": [ - { - "description": "[true] is valid", - "data": [true], - "valid": true - }, - { - "description": "[1] is invalid", - "data": [1], - "valid": false - }, - { - "description": "[1.0] is invalid", - "data": [1.0], - "valid": false - } - ] - }, - { - "description": "const with {\"a\": false} does not match {\"a\": 0}", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": {"a": false} - }, - "tests": [ - { - "description": "{\"a\": false} is valid", - "data": {"a": false}, - "valid": true - }, - { - "description": "{\"a\": 0} is invalid", - "data": {"a": 0}, - "valid": false - }, - { - "description": "{\"a\": 0.0} is invalid", - "data": {"a": 0.0}, - "valid": false - } - ] - }, - { - "description": "const with {\"a\": true} does not match {\"a\": 1}", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": {"a": true} - }, - "tests": [ - { - "description": "{\"a\": true} is valid", - "data": {"a": true}, - "valid": true - }, - { - "description": "{\"a\": 1} is invalid", - "data": {"a": 1}, - "valid": false - }, - { - "description": "{\"a\": 1.0} is invalid", - "data": {"a": 1.0}, - "valid": false - } - ] - }, - { - "description": "const with 0 does not match other zero-like types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 0 - }, - "tests": [ - { - "description": "false is invalid", - "data": false, - "valid": false - }, - { - "description": "integer zero is valid", - "data": 0, - "valid": true - }, - { - "description": "float zero is valid", - "data": 0.0, - "valid": true - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - }, - { - "description": "empty string is invalid", - "data": "", - "valid": false - } - ] - }, - { - "description": "const with 1 does not match true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 1 - }, - "tests": [ - { - "description": "true is invalid", - "data": true, - "valid": false - }, - { - "description": "integer one is valid", - "data": 1, - "valid": true - }, - { - "description": "float one is valid", - "data": 1.0, - "valid": true - } - ] - }, - { - "description": "const with -2.0 matches integer and float types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": -2.0 - }, - "tests": [ - { - "description": "integer -2 is valid", - "data": -2, - "valid": true - }, - { - "description": "integer 2 is invalid", - "data": 2, - "valid": false - }, - { - "description": "float -2.0 is valid", - "data": -2.0, - "valid": true - }, - { - "description": "float 2.0 is invalid", - "data": 2.0, - "valid": false - }, - { - "description": "float -2.00001 is invalid", - "data": -2.00001, - "valid": false - } - ] - }, - { - "description": "float and integers are equal up to 64-bit representation limits", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": 9007199254740992 - }, - "tests": [ - { - "description": "integer is valid", - "data": 9007199254740992, - "valid": true - }, - { - "description": "integer minus one is invalid", - "data": 9007199254740991, - "valid": false - }, - { - "description": "float is valid", - "data": 9007199254740992.0, - "valid": true - }, - { - "description": "float minus one is invalid", - "data": 9007199254740991.0, - "valid": false - } - ] - }, - { - "description": "nul characters in strings", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "const": "hello\u0000there" - }, - "tests": [ - { - "description": "match string with nul", - "data": "hello\u0000there", - "valid": true - }, - { - "description": "do not match string lacking nul", - "data": "hellothere", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/contains.json b/jsonschema/testdata/draft2020-12/contains.json deleted file mode 100644 index 08a00a75..00000000 --- a/jsonschema/testdata/draft2020-12/contains.json +++ /dev/null @@ -1,176 +0,0 @@ -[ - { - "description": "contains keyword validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"minimum": 5} - }, - "tests": [ - { - "description": "array with item matching schema (5) is valid", - "data": [3, 4, 5], - "valid": true - }, - { - "description": "array with item matching schema (6) is valid", - "data": [3, 4, 6], - "valid": true - }, - { - "description": "array with two items matching schema (5, 6) is valid", - "data": [3, 4, 5, 6], - "valid": true - }, - { - "description": "array without items matching schema is invalid", - "data": [2, 3, 4], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - }, - { - "description": "not array is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "contains keyword with const keyword", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": { "const": 5 } - }, - "tests": [ - { - "description": "array with item 5 is valid", - "data": [3, 4, 5], - "valid": true - }, - { - "description": "array with two items 5 is valid", - "data": [3, 4, 5, 5], - "valid": true - }, - { - "description": "array without item 5 is invalid", - "data": [1, 2, 3, 4], - "valid": false - } - ] - }, - { - "description": "contains keyword with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": true - }, - "tests": [ - { - "description": "any non-empty array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "contains keyword with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": false - }, - "tests": [ - { - "description": "any non-empty array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - }, - { - "description": "non-arrays are valid", - "data": "contains does not apply to strings", - "valid": true - } - ] - }, - { - "description": "items + contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": { "multipleOf": 2 }, - "contains": { "multipleOf": 3 } - }, - "tests": [ - { - "description": "matches items, does not match contains", - "data": [ 2, 4, 8 ], - "valid": false - }, - { - "description": "does not match items, matches contains", - "data": [ 3, 6, 9 ], - "valid": false - }, - { - "description": "matches both items and contains", - "data": [ 6, 12 ], - "valid": true - }, - { - "description": "matches neither items nor contains", - "data": [ 1, 5 ], - "valid": false - } - ] - }, - { - "description": "contains with false if subschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": { - "if": false, - "else": true - } - }, - "tests": [ - { - "description": "any non-empty array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "contains with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null items", - "data": [ null ], - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/default.json b/jsonschema/testdata/draft2020-12/default.json deleted file mode 100644 index ceb3ae27..00000000 --- a/jsonschema/testdata/draft2020-12/default.json +++ /dev/null @@ -1,82 +0,0 @@ -[ - { - "description": "invalid type for default", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": { - "type": "integer", - "default": [] - } - } - }, - "tests": [ - { - "description": "valid when property is specified", - "data": {"foo": 13}, - "valid": true - }, - { - "description": "still valid when the invalid default is used", - "data": {}, - "valid": true - } - ] - }, - { - "description": "invalid string value for default", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "bar": { - "type": "string", - "minLength": 4, - "default": "bad" - } - } - }, - "tests": [ - { - "description": "valid when property is specified", - "data": {"bar": "good"}, - "valid": true - }, - { - "description": "still valid when the invalid default is used", - "data": {}, - "valid": true - } - ] - }, - { - "description": "the default keyword does not do anything if the property is missing", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "alpha": { - "type": "number", - "maximum": 3, - "default": 5 - } - } - }, - "tests": [ - { - "description": "an explicit property value is checked against maximum (passing)", - "data": { "alpha": 1 }, - "valid": true - }, - { - "description": "an explicit property value is checked against maximum (failing)", - "data": { "alpha": 5 }, - "valid": false - }, - { - "description": "missing properties are not filled in with the default", - "data": {}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/defs.json b/jsonschema/testdata/draft2020-12/defs.json deleted file mode 100644 index da2a503b..00000000 --- a/jsonschema/testdata/draft2020-12/defs.json +++ /dev/null @@ -1,21 +0,0 @@ -[ - { - "description": "validate definition against metaschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "https://json-schema.org/draft/2020-12/schema" - }, - "tests": [ - { - "description": "valid definition schema", - "data": {"$defs": {"foo": {"type": "integer"}}}, - "valid": true - }, - { - "description": "invalid definition schema", - "data": {"$defs": {"foo": {"type": 1}}}, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/dependentRequired.json b/jsonschema/testdata/draft2020-12/dependentRequired.json deleted file mode 100644 index 2baa38e9..00000000 --- a/jsonschema/testdata/draft2020-12/dependentRequired.json +++ /dev/null @@ -1,152 +0,0 @@ -[ - { - "description": "single dependency", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": {"bar": ["foo"]} - }, - "tests": [ - { - "description": "neither", - "data": {}, - "valid": true - }, - { - "description": "nondependant", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "with dependency", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "missing dependency", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "ignores arrays", - "data": ["bar"], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "empty dependents", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": {"bar": []} - }, - "tests": [ - { - "description": "empty object", - "data": {}, - "valid": true - }, - { - "description": "object with one property", - "data": {"bar": 2}, - "valid": true - }, - { - "description": "non-object is valid", - "data": 1, - "valid": true - } - ] - }, - { - "description": "multiple dependents required", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": {"quux": ["foo", "bar"]} - }, - "tests": [ - { - "description": "neither", - "data": {}, - "valid": true - }, - { - "description": "nondependants", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "with dependencies", - "data": {"foo": 1, "bar": 2, "quux": 3}, - "valid": true - }, - { - "description": "missing dependency", - "data": {"foo": 1, "quux": 2}, - "valid": false - }, - { - "description": "missing other dependency", - "data": {"bar": 1, "quux": 2}, - "valid": false - }, - { - "description": "missing both dependencies", - "data": {"quux": 1}, - "valid": false - } - ] - }, - { - "description": "dependencies with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentRequired": { - "foo\nbar": ["foo\rbar"], - "foo\"bar": ["foo'bar"] - } - }, - "tests": [ - { - "description": "CRLF", - "data": { - "foo\nbar": 1, - "foo\rbar": 2 - }, - "valid": true - }, - { - "description": "quoted quotes", - "data": { - "foo'bar": 1, - "foo\"bar": 2 - }, - "valid": true - }, - { - "description": "CRLF missing dependent", - "data": { - "foo\nbar": 1, - "foo": 2 - }, - "valid": false - }, - { - "description": "quoted quotes missing dependent", - "data": { - "foo\"bar": 2 - }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/dependentSchemas.json b/jsonschema/testdata/draft2020-12/dependentSchemas.json deleted file mode 100644 index 1c5f0574..00000000 --- a/jsonschema/testdata/draft2020-12/dependentSchemas.json +++ /dev/null @@ -1,171 +0,0 @@ -[ - { - "description": "single dependency", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentSchemas": { - "bar": { - "properties": { - "foo": {"type": "integer"}, - "bar": {"type": "integer"} - } - } - } - }, - "tests": [ - { - "description": "valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "no dependency", - "data": {"foo": "quux"}, - "valid": true - }, - { - "description": "wrong type", - "data": {"foo": "quux", "bar": 2}, - "valid": false - }, - { - "description": "wrong type other", - "data": {"foo": 2, "bar": "quux"}, - "valid": false - }, - { - "description": "wrong type both", - "data": {"foo": "quux", "bar": "quux"}, - "valid": false - }, - { - "description": "ignores arrays", - "data": ["bar"], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "boolean subschemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentSchemas": { - "foo": true, - "bar": false - } - }, - "tests": [ - { - "description": "object with property having schema true is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with property having schema false is invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "object with both properties is invalid", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "dependencies with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "dependentSchemas": { - "foo\tbar": {"minProperties": 4}, - "foo'bar": {"required": ["foo\"bar"]} - } - }, - "tests": [ - { - "description": "quoted tab", - "data": { - "foo\tbar": 1, - "a": 2, - "b": 3, - "c": 4 - }, - "valid": true - }, - { - "description": "quoted quote", - "data": { - "foo'bar": {"foo\"bar": 1} - }, - "valid": false - }, - { - "description": "quoted tab invalid under dependent schema", - "data": { - "foo\tbar": 1, - "a": 2 - }, - "valid": false - }, - { - "description": "quoted quote invalid under dependent schema", - "data": {"foo'bar": 1}, - "valid": false - } - ] - }, - { - "description": "dependent subschema incompatible with root", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {} - }, - "dependentSchemas": { - "foo": { - "properties": { - "bar": {} - }, - "additionalProperties": false - } - } - }, - "tests": [ - { - "description": "matches root", - "data": {"foo": 1}, - "valid": false - }, - { - "description": "matches dependency", - "data": {"bar": 1}, - "valid": true - }, - { - "description": "matches both", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "no dependency", - "data": {"baz": 1}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/dynamicRef.json b/jsonschema/testdata/draft2020-12/dynamicRef.json deleted file mode 100644 index ffa211ba..00000000 --- a/jsonschema/testdata/draft2020-12/dynamicRef.json +++ /dev/null @@ -1,815 +0,0 @@ -[ - { - "description": "A $dynamicRef to a $dynamicAnchor in the same schema resource behaves like a normal $ref to an $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamicRef-dynamicAnchor-same-schema/root", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $dynamicRef to an $anchor in the same schema resource behaves like a normal $ref to an $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamicRef-anchor-same-schema/root", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "foo": { - "$anchor": "items", - "type": "string" - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $ref to a $dynamicAnchor in the same schema resource behaves like a normal $ref to an $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/ref-dynamicAnchor-same-schema/root", - "type": "array", - "items": { "$ref": "#items" }, - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $dynamicRef resolves to the first $dynamicAnchor still in scope that is encountered when the schema is evaluated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/typical-dynamic-resolution/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "A $dynamicRef without anchor in fragment behaves identical to $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamicRef-without-anchor/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#/$defs/items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items", - "type": "number" - } - } - } - } - }, - "tests": [ - { - "description": "An array of strings is invalid", - "data": ["foo", "bar"], - "valid": false - }, - { - "description": "An array of numbers is valid", - "data": [24, 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef with intermediate scopes that don't include a matching $dynamicAnchor does not affect dynamic scope resolution", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-resolution-with-intermediate-scopes/root", - "$ref": "intermediate-scope", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "intermediate-scope": { - "$id": "intermediate-scope", - "$ref": "list" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "An array of strings is valid", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "An array containing non-strings is invalid", - "data": ["foo", 42], - "valid": false - } - ] - }, - { - "description": "An $anchor with the same name as a $dynamicAnchor is not used for dynamic scope resolution", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-resolution-ignores-anchors/root", - "$ref": "list", - "$defs": { - "foo": { - "$anchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to satisfy the bookending requirement", - "$dynamicAnchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "Any array is valid", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef without a matching $dynamicAnchor in the same schema resource behaves like a normal $ref to $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-resolution-without-bookend/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to give the reference somewhere to resolve to when it behaves like $ref", - "$anchor": "items" - } - } - } - } - }, - "tests": [ - { - "description": "Any array is valid", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef with a non-matching $dynamicAnchor in the same schema resource behaves like a normal $ref to $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/unmatched-dynamic-anchor/root", - "$ref": "list", - "$defs": { - "foo": { - "$dynamicAnchor": "items", - "type": "string" - }, - "list": { - "$id": "list", - "type": "array", - "items": { "$dynamicRef": "#items" }, - "$defs": { - "items": { - "$comment": "This is only needed to give the reference somewhere to resolve to when it behaves like $ref", - "$anchor": "items", - "$dynamicAnchor": "foo" - } - } - } - } - }, - "tests": [ - { - "description": "Any array is valid", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "A $dynamicRef that initially resolves to a schema with a matching $dynamicAnchor resolves to the first $dynamicAnchor in the dynamic scope", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/relative-dynamic-reference/root", - "$dynamicAnchor": "meta", - "type": "object", - "properties": { - "foo": { "const": "pass" } - }, - "$ref": "extended", - "$defs": { - "extended": { - "$id": "extended", - "$dynamicAnchor": "meta", - "type": "object", - "properties": { - "bar": { "$ref": "bar" } - } - }, - "bar": { - "$id": "bar", - "type": "object", - "properties": { - "baz": { "$dynamicRef": "extended#meta" } - } - } - } - }, - "tests": [ - { - "description": "The recursive part is valid against the root", - "data": { - "foo": "pass", - "bar": { - "baz": { "foo": "pass" } - } - }, - "valid": true - }, - { - "description": "The recursive part is not valid against the root", - "data": { - "foo": "pass", - "bar": { - "baz": { "foo": "fail" } - } - }, - "valid": false - } - ] - }, - { - "description": "A $dynamicRef that initially resolves to a schema without a matching $dynamicAnchor behaves like a normal $ref to $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/relative-dynamic-reference-without-bookend/root", - "$dynamicAnchor": "meta", - "type": "object", - "properties": { - "foo": { "const": "pass" } - }, - "$ref": "extended", - "$defs": { - "extended": { - "$id": "extended", - "$anchor": "meta", - "type": "object", - "properties": { - "bar": { "$ref": "bar" } - } - }, - "bar": { - "$id": "bar", - "type": "object", - "properties": { - "baz": { "$dynamicRef": "extended#meta" } - } - } - } - }, - "tests": [ - { - "description": "The recursive part doesn't need to validate against the root", - "data": { - "foo": "pass", - "bar": { - "baz": { "foo": "fail" } - } - }, - "valid": true - } - ] - }, - { - "description": "multiple dynamic paths to the $dynamicRef keyword", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-ref-with-multiple-paths/main", - "if": { - "properties": { - "kindOfList": { "const": "numbers" } - }, - "required": ["kindOfList"] - }, - "then": { "$ref": "numberList" }, - "else": { "$ref": "stringList" }, - - "$defs": { - "genericList": { - "$id": "genericList", - "properties": { - "list": { - "items": { "$dynamicRef": "#itemType" } - } - }, - "$defs": { - "defaultItemType": { - "$comment": "Only needed to satisfy bookending requirement", - "$dynamicAnchor": "itemType" - } - } - }, - "numberList": { - "$id": "numberList", - "$defs": { - "itemType": { - "$dynamicAnchor": "itemType", - "type": "number" - } - }, - "$ref": "genericList" - }, - "stringList": { - "$id": "stringList", - "$defs": { - "itemType": { - "$dynamicAnchor": "itemType", - "type": "string" - } - }, - "$ref": "genericList" - } - } - }, - "tests": [ - { - "description": "number list with number values", - "data": { - "kindOfList": "numbers", - "list": [1.1] - }, - "valid": true - }, - { - "description": "number list with string values", - "data": { - "kindOfList": "numbers", - "list": ["foo"] - }, - "valid": false - }, - { - "description": "string list with number values", - "data": { - "kindOfList": "strings", - "list": [1.1] - }, - "valid": false - }, - { - "description": "string list with string values", - "data": { - "kindOfList": "strings", - "list": ["foo"] - }, - "valid": true - } - ] - }, - { - "description": "after leaving a dynamic scope, it is not used by a $dynamicRef", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-ref-leaving-dynamic-scope/main", - "if": { - "$id": "first_scope", - "$defs": { - "thingy": { - "$comment": "this is first_scope#thingy", - "$dynamicAnchor": "thingy", - "type": "number" - } - } - }, - "then": { - "$id": "second_scope", - "$ref": "start", - "$defs": { - "thingy": { - "$comment": "this is second_scope#thingy, the final destination of the $dynamicRef", - "$dynamicAnchor": "thingy", - "type": "null" - } - } - }, - "$defs": { - "start": { - "$comment": "this is the landing spot from $ref", - "$id": "start", - "$dynamicRef": "inner_scope#thingy" - }, - "thingy": { - "$comment": "this is the first stop for the $dynamicRef", - "$id": "inner_scope", - "$dynamicAnchor": "thingy", - "type": "string" - } - } - }, - "tests": [ - { - "description": "string matches /$defs/thingy, but the $dynamicRef does not stop here", - "data": "a string", - "valid": false - }, - { - "description": "first_scope is not in dynamic scope for the $dynamicRef", - "data": 42, - "valid": false - }, - { - "description": "/then/$defs/thingy is the final stop for the $dynamicRef", - "data": null, - "valid": true - } - ] - }, - { - "description": "strict-tree schema, guards against misspelled properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-tree.json", - "$dynamicAnchor": "node", - - "$ref": "tree.json", - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "instance with misspelled field", - "data": { - "children": [{ - "daat": 1 - }] - }, - "valid": false - }, - { - "description": "instance with correct field", - "data": { - "children": [{ - "data": 1 - }] - }, - "valid": true - } - ] - }, - { - "description": "tests for implementation dynamic anchor and reference link", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-extendible.json", - "$ref": "extendible-dynamic-ref.json", - "$defs": { - "elements": { - "$dynamicAnchor": "elements", - "properties": { - "a": true - }, - "required": ["a"], - "additionalProperties": false - } - } - }, - "tests": [ - { - "description": "incorrect parent schema", - "data": { - "a": true - }, - "valid": false - }, - { - "description": "incorrect extended schema", - "data": { - "elements": [ - { "b": 1 } - ] - }, - "valid": false - }, - { - "description": "correct extended schema", - "data": { - "elements": [ - { "a": 1 } - ] - }, - "valid": true - } - ] - }, - { - "description": "$ref and $dynamicAnchor are independent of order - $defs first", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-extendible-allof-defs-first.json", - "allOf": [ - { - "$ref": "extendible-dynamic-ref.json" - }, - { - "$defs": { - "elements": { - "$dynamicAnchor": "elements", - "properties": { - "a": true - }, - "required": ["a"], - "additionalProperties": false - } - } - } - ] - }, - "tests": [ - { - "description": "incorrect parent schema", - "data": { - "a": true - }, - "valid": false - }, - { - "description": "incorrect extended schema", - "data": { - "elements": [ - { "b": 1 } - ] - }, - "valid": false - }, - { - "description": "correct extended schema", - "data": { - "elements": [ - { "a": 1 } - ] - }, - "valid": true - } - ] - }, - { - "description": "$ref and $dynamicAnchor are independent of order - $ref first", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/strict-extendible-allof-ref-first.json", - "allOf": [ - { - "$defs": { - "elements": { - "$dynamicAnchor": "elements", - "properties": { - "a": true - }, - "required": ["a"], - "additionalProperties": false - } - } - }, - { - "$ref": "extendible-dynamic-ref.json" - } - ] - }, - "tests": [ - { - "description": "incorrect parent schema", - "data": { - "a": true - }, - "valid": false - }, - { - "description": "incorrect extended schema", - "data": { - "elements": [ - { "b": 1 } - ] - }, - "valid": false - }, - { - "description": "correct extended schema", - "data": { - "elements": [ - { "a": 1 } - ] - }, - "valid": true - } - ] - }, - { - "description": "$ref to $dynamicRef finds detached $dynamicAnchor", - "schema": { - "$ref": "http://localhost:1234/draft2020-12/detached-dynamicref.json#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "$dynamicRef points to a boolean schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "true": true, - "false": false - }, - "properties": { - "true": { - "$dynamicRef": "#/$defs/true" - }, - "false": { - "$dynamicRef": "#/$defs/false" - } - } - }, - "tests": [ - { - "description": "follow $dynamicRef to a true schema", - "data": { "true": 1 }, - "valid": true - }, - { - "description": "follow $dynamicRef to a false schema", - "data": { "false": 1 }, - "valid": false - } - ] - }, - { - "description": "$dynamicRef skips over intermediate resources - direct reference", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://test.json-schema.org/dynamic-ref-skips-intermediate-resource/main", - "type": "object", - "properties": { - "bar-item": { - "$ref": "item" - } - }, - "$defs": { - "bar": { - "$id": "bar", - "type": "array", - "items": { - "$ref": "item" - }, - "$defs": { - "item": { - "$id": "item", - "type": "object", - "properties": { - "content": { - "$dynamicRef": "#content" - } - }, - "$defs": { - "defaultContent": { - "$dynamicAnchor": "content", - "type": "integer" - } - } - }, - "content": { - "$dynamicAnchor": "content", - "type": "string" - } - } - } - } - }, - "tests": [ - { - "description": "integer property passes", - "data": { "bar-item": { "content": 42 } }, - "valid": true - }, - { - "description": "string property fails", - "data": { "bar-item": { "content": "value" } }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/enum.json b/jsonschema/testdata/draft2020-12/enum.json deleted file mode 100644 index c8f35eac..00000000 --- a/jsonschema/testdata/draft2020-12/enum.json +++ /dev/null @@ -1,358 +0,0 @@ -[ - { - "description": "simple enum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [1, 2, 3] - }, - "tests": [ - { - "description": "one of the enum is valid", - "data": 1, - "valid": true - }, - { - "description": "something else is invalid", - "data": 4, - "valid": false - } - ] - }, - { - "description": "heterogeneous enum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [6, "foo", [], true, {"foo": 12}] - }, - "tests": [ - { - "description": "one of the enum is valid", - "data": [], - "valid": true - }, - { - "description": "something else is invalid", - "data": null, - "valid": false - }, - { - "description": "objects are deep compared", - "data": {"foo": false}, - "valid": false - }, - { - "description": "valid object matches", - "data": {"foo": 12}, - "valid": true - }, - { - "description": "extra properties in object is invalid", - "data": {"foo": 12, "boo": 42}, - "valid": false - } - ] - }, - { - "description": "heterogeneous enum-with-null validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [6, null] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "number is valid", - "data": 6, - "valid": true - }, - { - "description": "something else is invalid", - "data": "test", - "valid": false - } - ] - }, - { - "description": "enums in properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type":"object", - "properties": { - "foo": {"enum":["foo"]}, - "bar": {"enum":["bar"]} - }, - "required": ["bar"] - }, - "tests": [ - { - "description": "both properties are valid", - "data": {"foo":"foo", "bar":"bar"}, - "valid": true - }, - { - "description": "wrong foo value", - "data": {"foo":"foot", "bar":"bar"}, - "valid": false - }, - { - "description": "wrong bar value", - "data": {"foo":"foo", "bar":"bart"}, - "valid": false - }, - { - "description": "missing optional property is valid", - "data": {"bar":"bar"}, - "valid": true - }, - { - "description": "missing required property is invalid", - "data": {"foo":"foo"}, - "valid": false - }, - { - "description": "missing all properties is invalid", - "data": {}, - "valid": false - } - ] - }, - { - "description": "enum with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": ["foo\nbar", "foo\rbar"] - }, - "tests": [ - { - "description": "member 1 is valid", - "data": "foo\nbar", - "valid": true - }, - { - "description": "member 2 is valid", - "data": "foo\rbar", - "valid": true - }, - { - "description": "another string is invalid", - "data": "abc", - "valid": false - } - ] - }, - { - "description": "enum with false does not match 0", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [false] - }, - "tests": [ - { - "description": "false is valid", - "data": false, - "valid": true - }, - { - "description": "integer zero is invalid", - "data": 0, - "valid": false - }, - { - "description": "float zero is invalid", - "data": 0.0, - "valid": false - } - ] - }, - { - "description": "enum with [false] does not match [0]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[false]] - }, - "tests": [ - { - "description": "[false] is valid", - "data": [false], - "valid": true - }, - { - "description": "[0] is invalid", - "data": [0], - "valid": false - }, - { - "description": "[0.0] is invalid", - "data": [0.0], - "valid": false - } - ] - }, - { - "description": "enum with true does not match 1", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [true] - }, - "tests": [ - { - "description": "true is valid", - "data": true, - "valid": true - }, - { - "description": "integer one is invalid", - "data": 1, - "valid": false - }, - { - "description": "float one is invalid", - "data": 1.0, - "valid": false - } - ] - }, - { - "description": "enum with [true] does not match [1]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[true]] - }, - "tests": [ - { - "description": "[true] is valid", - "data": [true], - "valid": true - }, - { - "description": "[1] is invalid", - "data": [1], - "valid": false - }, - { - "description": "[1.0] is invalid", - "data": [1.0], - "valid": false - } - ] - }, - { - "description": "enum with 0 does not match false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [0] - }, - "tests": [ - { - "description": "false is invalid", - "data": false, - "valid": false - }, - { - "description": "integer zero is valid", - "data": 0, - "valid": true - }, - { - "description": "float zero is valid", - "data": 0.0, - "valid": true - } - ] - }, - { - "description": "enum with [0] does not match [false]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[0]] - }, - "tests": [ - { - "description": "[false] is invalid", - "data": [false], - "valid": false - }, - { - "description": "[0] is valid", - "data": [0], - "valid": true - }, - { - "description": "[0.0] is valid", - "data": [0.0], - "valid": true - } - ] - }, - { - "description": "enum with 1 does not match true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [1] - }, - "tests": [ - { - "description": "true is invalid", - "data": true, - "valid": false - }, - { - "description": "integer one is valid", - "data": 1, - "valid": true - }, - { - "description": "float one is valid", - "data": 1.0, - "valid": true - } - ] - }, - { - "description": "enum with [1] does not match [true]", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [[1]] - }, - "tests": [ - { - "description": "[true] is invalid", - "data": [true], - "valid": false - }, - { - "description": "[1] is valid", - "data": [1], - "valid": true - }, - { - "description": "[1.0] is valid", - "data": [1.0], - "valid": true - } - ] - }, - { - "description": "nul characters in strings", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "enum": [ "hello\u0000there" ] - }, - "tests": [ - { - "description": "match string with nul", - "data": "hello\u0000there", - "valid": true - }, - { - "description": "do not match string lacking nul", - "data": "hellothere", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/exclusiveMaximum.json b/jsonschema/testdata/draft2020-12/exclusiveMaximum.json deleted file mode 100644 index 05db2335..00000000 --- a/jsonschema/testdata/draft2020-12/exclusiveMaximum.json +++ /dev/null @@ -1,31 +0,0 @@ -[ - { - "description": "exclusiveMaximum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "exclusiveMaximum": 3.0 - }, - "tests": [ - { - "description": "below the exclusiveMaximum is valid", - "data": 2.2, - "valid": true - }, - { - "description": "boundary point is invalid", - "data": 3.0, - "valid": false - }, - { - "description": "above the exclusiveMaximum is invalid", - "data": 3.5, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/exclusiveMinimum.json b/jsonschema/testdata/draft2020-12/exclusiveMinimum.json deleted file mode 100644 index 00af9d7f..00000000 --- a/jsonschema/testdata/draft2020-12/exclusiveMinimum.json +++ /dev/null @@ -1,31 +0,0 @@ -[ - { - "description": "exclusiveMinimum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "exclusiveMinimum": 1.1 - }, - "tests": [ - { - "description": "above the exclusiveMinimum is valid", - "data": 1.2, - "valid": true - }, - { - "description": "boundary point is invalid", - "data": 1.1, - "valid": false - }, - { - "description": "below the exclusiveMinimum is invalid", - "data": 0.6, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/if-then-else.json b/jsonschema/testdata/draft2020-12/if-then-else.json deleted file mode 100644 index 1c35d7e6..00000000 --- a/jsonschema/testdata/draft2020-12/if-then-else.json +++ /dev/null @@ -1,268 +0,0 @@ -[ - { - "description": "ignore if without then or else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "const": 0 - } - }, - "tests": [ - { - "description": "valid when valid against lone if", - "data": 0, - "valid": true - }, - { - "description": "valid when invalid against lone if", - "data": "hello", - "valid": true - } - ] - }, - { - "description": "ignore then without if", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "then": { - "const": 0 - } - }, - "tests": [ - { - "description": "valid when valid against lone then", - "data": 0, - "valid": true - }, - { - "description": "valid when invalid against lone then", - "data": "hello", - "valid": true - } - ] - }, - { - "description": "ignore else without if", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "else": { - "const": 0 - } - }, - "tests": [ - { - "description": "valid when valid against lone else", - "data": 0, - "valid": true - }, - { - "description": "valid when invalid against lone else", - "data": "hello", - "valid": true - } - ] - }, - { - "description": "if and then without else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "exclusiveMaximum": 0 - }, - "then": { - "minimum": -10 - } - }, - "tests": [ - { - "description": "valid through then", - "data": -1, - "valid": true - }, - { - "description": "invalid through then", - "data": -100, - "valid": false - }, - { - "description": "valid when if test fails", - "data": 3, - "valid": true - } - ] - }, - { - "description": "if and else without then", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "exclusiveMaximum": 0 - }, - "else": { - "multipleOf": 2 - } - }, - "tests": [ - { - "description": "valid when if test passes", - "data": -1, - "valid": true - }, - { - "description": "valid through else", - "data": 4, - "valid": true - }, - { - "description": "invalid through else", - "data": 3, - "valid": false - } - ] - }, - { - "description": "validate against correct branch, then vs else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "exclusiveMaximum": 0 - }, - "then": { - "minimum": -10 - }, - "else": { - "multipleOf": 2 - } - }, - "tests": [ - { - "description": "valid through then", - "data": -1, - "valid": true - }, - { - "description": "invalid through then", - "data": -100, - "valid": false - }, - { - "description": "valid through else", - "data": 4, - "valid": true - }, - { - "description": "invalid through else", - "data": 3, - "valid": false - } - ] - }, - { - "description": "non-interference across combined schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "if": { - "exclusiveMaximum": 0 - } - }, - { - "then": { - "minimum": -10 - } - }, - { - "else": { - "multipleOf": 2 - } - } - ] - }, - "tests": [ - { - "description": "valid, but would have been invalid through then", - "data": -100, - "valid": true - }, - { - "description": "valid, but would have been invalid through else", - "data": 3, - "valid": true - } - ] - }, - { - "description": "if with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": true, - "then": { "const": "then" }, - "else": { "const": "else" } - }, - "tests": [ - { - "description": "boolean schema true in if always chooses the then path (valid)", - "data": "then", - "valid": true - }, - { - "description": "boolean schema true in if always chooses the then path (invalid)", - "data": "else", - "valid": false - } - ] - }, - { - "description": "if with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": false, - "then": { "const": "then" }, - "else": { "const": "else" } - }, - "tests": [ - { - "description": "boolean schema false in if always chooses the else path (invalid)", - "data": "then", - "valid": false - }, - { - "description": "boolean schema false in if always chooses the else path (valid)", - "data": "else", - "valid": true - } - ] - }, - { - "description": "if appears at the end when serialized (keyword processing sequence)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "then": { "const": "yes" }, - "else": { "const": "other" }, - "if": { "maxLength": 4 } - }, - "tests": [ - { - "description": "yes redirects to then and passes", - "data": "yes", - "valid": true - }, - { - "description": "other redirects to else and passes", - "data": "other", - "valid": true - }, - { - "description": "no redirects to then and fails", - "data": "no", - "valid": false - }, - { - "description": "invalid redirects to else and fails", - "data": "invalid", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/infinite-loop-detection.json b/jsonschema/testdata/draft2020-12/infinite-loop-detection.json deleted file mode 100644 index 46f157a3..00000000 --- a/jsonschema/testdata/draft2020-12/infinite-loop-detection.json +++ /dev/null @@ -1,37 +0,0 @@ -[ - { - "description": "evaluating the same schema location against the same data location twice is not a sign of an infinite loop", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "int": { "type": "integer" } - }, - "allOf": [ - { - "properties": { - "foo": { - "$ref": "#/$defs/int" - } - } - }, - { - "additionalProperties": { - "$ref": "#/$defs/int" - } - } - ] - }, - "tests": [ - { - "description": "passing case", - "data": { "foo": 1 }, - "valid": true - }, - { - "description": "failing case", - "data": { "foo": "a string" }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/items.json b/jsonschema/testdata/draft2020-12/items.json deleted file mode 100644 index 6a3e1cf2..00000000 --- a/jsonschema/testdata/draft2020-12/items.json +++ /dev/null @@ -1,304 +0,0 @@ -[ - { - "description": "a schema given for items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": {"type": "integer"} - }, - "tests": [ - { - "description": "valid items", - "data": [ 1, 2, 3 ], - "valid": true - }, - { - "description": "wrong type of items", - "data": [1, "x"], - "valid": false - }, - { - "description": "ignores non-arrays", - "data": {"foo" : "bar"}, - "valid": true - }, - { - "description": "JavaScript pseudo-array is valid", - "data": { - "0": "invalid", - "length": 1 - }, - "valid": true - } - ] - }, - { - "description": "items with boolean schema (true)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": true - }, - "tests": [ - { - "description": "any array is valid", - "data": [ 1, "foo", true ], - "valid": true - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "items with boolean schema (false)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": false - }, - "tests": [ - { - "description": "any non-empty array is invalid", - "data": [ 1, "foo", true ], - "valid": false - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "items and subitems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "item": { - "type": "array", - "items": false, - "prefixItems": [ - { "$ref": "#/$defs/sub-item" }, - { "$ref": "#/$defs/sub-item" } - ] - }, - "sub-item": { - "type": "object", - "required": ["foo"] - } - }, - "type": "array", - "items": false, - "prefixItems": [ - { "$ref": "#/$defs/item" }, - { "$ref": "#/$defs/item" }, - { "$ref": "#/$defs/item" } - ] - }, - "tests": [ - { - "description": "valid items", - "data": [ - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": true - }, - { - "description": "too many items", - "data": [ - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "too many sub-items", - "data": [ - [ {"foo": null}, {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "wrong item", - "data": [ - {"foo": null}, - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "wrong sub-item", - "data": [ - [ {}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ], - [ {"foo": null}, {"foo": null} ] - ], - "valid": false - }, - { - "description": "fewer items is valid", - "data": [ - [ {"foo": null} ], - [ {"foo": null} ] - ], - "valid": true - } - ] - }, - { - "description": "nested items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "array", - "items": { - "type": "array", - "items": { - "type": "array", - "items": { - "type": "array", - "items": { - "type": "number" - } - } - } - } - }, - "tests": [ - { - "description": "valid nested array", - "data": [[[[1]], [[2],[3]]], [[[4], [5], [6]]]], - "valid": true - }, - { - "description": "nested array with invalid type", - "data": [[[["1"]], [[2],[3]]], [[[4], [5], [6]]]], - "valid": false - }, - { - "description": "not deep enough", - "data": [[[1], [2],[3]], [[4], [5], [6]]], - "valid": false - } - ] - }, - { - "description": "prefixItems with no additional items allowed", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{}, {}, {}], - "items": false - }, - "tests": [ - { - "description": "empty array", - "data": [ ], - "valid": true - }, - { - "description": "fewer number of items present (1)", - "data": [ 1 ], - "valid": true - }, - { - "description": "fewer number of items present (2)", - "data": [ 1, 2 ], - "valid": true - }, - { - "description": "equal number of items present", - "data": [ 1, 2, 3 ], - "valid": true - }, - { - "description": "additional items are not permitted", - "data": [ 1, 2, 3, 4 ], - "valid": false - } - ] - }, - { - "description": "items does not look in applicators, valid case", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { "prefixItems": [ { "minimum": 3 } ] } - ], - "items": { "minimum": 5 } - }, - "tests": [ - { - "description": "prefixItems in allOf does not constrain items, invalid case", - "data": [ 3, 5 ], - "valid": false - }, - { - "description": "prefixItems in allOf does not constrain items, valid case", - "data": [ 5, 5 ], - "valid": true - } - ] - }, - { - "description": "prefixItems validation adjusts the starting index for items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ { "type": "string" } ], - "items": { "type": "integer" } - }, - "tests": [ - { - "description": "valid items", - "data": [ "x", 2, 3 ], - "valid": true - }, - { - "description": "wrong type of second item", - "data": [ "x", "y" ], - "valid": false - } - ] - }, - { - "description": "items with heterogeneous array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{}], - "items": false - }, - "tests": [ - { - "description": "heterogeneous invalid instance", - "data": [ "foo", "bar", 37 ], - "valid": false - }, - { - "description": "valid instance", - "data": [ null ], - "valid": true - } - ] - }, - { - "description": "items with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null elements", - "data": [ null ], - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxContains.json b/jsonschema/testdata/draft2020-12/maxContains.json deleted file mode 100644 index 8cd3ca74..00000000 --- a/jsonschema/testdata/draft2020-12/maxContains.json +++ /dev/null @@ -1,102 +0,0 @@ -[ - { - "description": "maxContains without contains is ignored", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxContains": 1 - }, - "tests": [ - { - "description": "one item valid against lone maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "two items still valid against lone maxContains", - "data": [ 1, 2 ], - "valid": true - } - ] - }, - { - "description": "maxContains with contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 1 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "all elements match, valid maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "all elements match, invalid maxContains", - "data": [ 1, 1 ], - "valid": false - }, - { - "description": "some elements match, valid maxContains", - "data": [ 1, 2 ], - "valid": true - }, - { - "description": "some elements match, invalid maxContains", - "data": [ 1, 2, 1 ], - "valid": false - } - ] - }, - { - "description": "maxContains with contains, value with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 1.0 - }, - "tests": [ - { - "description": "one element matches, valid maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "too many elements match, invalid maxContains", - "data": [ 1, 1 ], - "valid": false - } - ] - }, - { - "description": "minContains < maxContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 1, - "maxContains": 3 - }, - "tests": [ - { - "description": "actual < minContains < maxContains", - "data": [ ], - "valid": false - }, - { - "description": "minContains < actual < maxContains", - "data": [ 1, 1 ], - "valid": true - }, - { - "description": "minContains < maxContains < actual", - "data": [ 1, 1, 1, 1 ], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxItems.json b/jsonschema/testdata/draft2020-12/maxItems.json deleted file mode 100644 index f6a6b7c9..00000000 --- a/jsonschema/testdata/draft2020-12/maxItems.json +++ /dev/null @@ -1,50 +0,0 @@ -[ - { - "description": "maxItems validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxItems": 2 - }, - "tests": [ - { - "description": "shorter is valid", - "data": [1], - "valid": true - }, - { - "description": "exact length is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "too long is invalid", - "data": [1, 2, 3], - "valid": false - }, - { - "description": "ignores non-arrays", - "data": "foobar", - "valid": true - } - ] - }, - { - "description": "maxItems validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxItems": 2.0 - }, - "tests": [ - { - "description": "shorter is valid", - "data": [1], - "valid": true - }, - { - "description": "too long is invalid", - "data": [1, 2, 3], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxLength.json b/jsonschema/testdata/draft2020-12/maxLength.json deleted file mode 100644 index 7462726d..00000000 --- a/jsonschema/testdata/draft2020-12/maxLength.json +++ /dev/null @@ -1,55 +0,0 @@ -[ - { - "description": "maxLength validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxLength": 2 - }, - "tests": [ - { - "description": "shorter is valid", - "data": "f", - "valid": true - }, - { - "description": "exact length is valid", - "data": "fo", - "valid": true - }, - { - "description": "too long is invalid", - "data": "foo", - "valid": false - }, - { - "description": "ignores non-strings", - "data": 100, - "valid": true - }, - { - "description": "two graphemes is long enough", - "data": "\uD83D\uDCA9\uD83D\uDCA9", - "valid": true - } - ] - }, - { - "description": "maxLength validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxLength": 2.0 - }, - "tests": [ - { - "description": "shorter is valid", - "data": "f", - "valid": true - }, - { - "description": "too long is invalid", - "data": "foo", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maxProperties.json b/jsonschema/testdata/draft2020-12/maxProperties.json deleted file mode 100644 index 73ae7316..00000000 --- a/jsonschema/testdata/draft2020-12/maxProperties.json +++ /dev/null @@ -1,79 +0,0 @@ -[ - { - "description": "maxProperties validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxProperties": 2 - }, - "tests": [ - { - "description": "shorter is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "exact length is valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "too long is invalid", - "data": {"foo": 1, "bar": 2, "baz": 3}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [1, 2, 3], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "maxProperties validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxProperties": 2.0 - }, - "tests": [ - { - "description": "shorter is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "too long is invalid", - "data": {"foo": 1, "bar": 2, "baz": 3}, - "valid": false - } - ] - }, - { - "description": "maxProperties = 0 means the object is empty", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maxProperties": 0 - }, - "tests": [ - { - "description": "no properties is valid", - "data": {}, - "valid": true - }, - { - "description": "one property is invalid", - "data": { "foo": 1 }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/maximum.json b/jsonschema/testdata/draft2020-12/maximum.json deleted file mode 100644 index b99a541e..00000000 --- a/jsonschema/testdata/draft2020-12/maximum.json +++ /dev/null @@ -1,60 +0,0 @@ -[ - { - "description": "maximum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maximum": 3.0 - }, - "tests": [ - { - "description": "below the maximum is valid", - "data": 2.6, - "valid": true - }, - { - "description": "boundary point is valid", - "data": 3.0, - "valid": true - }, - { - "description": "above the maximum is invalid", - "data": 3.5, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - }, - { - "description": "maximum validation with unsigned integer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "maximum": 300 - }, - "tests": [ - { - "description": "below the maximum is invalid", - "data": 299.97, - "valid": true - }, - { - "description": "boundary point integer is valid", - "data": 300, - "valid": true - }, - { - "description": "boundary point float is valid", - "data": 300.00, - "valid": true - }, - { - "description": "above the maximum is invalid", - "data": 300.5, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minContains.json b/jsonschema/testdata/draft2020-12/minContains.json deleted file mode 100644 index ee72d7d6..00000000 --- a/jsonschema/testdata/draft2020-12/minContains.json +++ /dev/null @@ -1,224 +0,0 @@ -[ - { - "description": "minContains without contains is ignored", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minContains": 1 - }, - "tests": [ - { - "description": "one item valid against lone minContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "zero items still valid against lone minContains", - "data": [], - "valid": true - } - ] - }, - { - "description": "minContains=1 with contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 1 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "no elements match", - "data": [ 2 ], - "valid": false - }, - { - "description": "single element matches, valid minContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "some elements match, valid minContains", - "data": [ 1, 2 ], - "valid": true - }, - { - "description": "all elements match, valid minContains", - "data": [ 1, 1 ], - "valid": true - } - ] - }, - { - "description": "minContains=2 with contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 2 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "all elements match, invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "some elements match, invalid minContains", - "data": [ 1, 2 ], - "valid": false - }, - { - "description": "all elements match, valid minContains (exactly as needed)", - "data": [ 1, 1 ], - "valid": true - }, - { - "description": "all elements match, valid minContains (more than needed)", - "data": [ 1, 1, 1 ], - "valid": true - }, - { - "description": "some elements match, valid minContains", - "data": [ 1, 2, 1 ], - "valid": true - } - ] - }, - { - "description": "minContains=2 with contains with a decimal value", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 2.0 - }, - "tests": [ - { - "description": "one element matches, invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "both elements match, valid minContains", - "data": [ 1, 1 ], - "valid": true - } - ] - }, - { - "description": "maxContains = minContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 2, - "minContains": 2 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "all elements match, invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "all elements match, invalid maxContains", - "data": [ 1, 1, 1 ], - "valid": false - }, - { - "description": "all elements match, valid maxContains and minContains", - "data": [ 1, 1 ], - "valid": true - } - ] - }, - { - "description": "maxContains < minContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "maxContains": 1, - "minContains": 3 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": false - }, - { - "description": "invalid minContains", - "data": [ 1 ], - "valid": false - }, - { - "description": "invalid maxContains", - "data": [ 1, 1, 1 ], - "valid": false - }, - { - "description": "invalid maxContains and minContains", - "data": [ 1, 1 ], - "valid": false - } - ] - }, - { - "description": "minContains = 0", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 0 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": true - }, - { - "description": "minContains = 0 makes contains always pass", - "data": [ 2 ], - "valid": true - } - ] - }, - { - "description": "minContains = 0 with maxContains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "contains": {"const": 1}, - "minContains": 0, - "maxContains": 1 - }, - "tests": [ - { - "description": "empty data", - "data": [ ], - "valid": true - }, - { - "description": "not more than maxContains", - "data": [ 1 ], - "valid": true - }, - { - "description": "too many", - "data": [ 1, 1 ], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minItems.json b/jsonschema/testdata/draft2020-12/minItems.json deleted file mode 100644 index 9d6a8b6d..00000000 --- a/jsonschema/testdata/draft2020-12/minItems.json +++ /dev/null @@ -1,50 +0,0 @@ -[ - { - "description": "minItems validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minItems": 1 - }, - "tests": [ - { - "description": "longer is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "exact length is valid", - "data": [1], - "valid": true - }, - { - "description": "too short is invalid", - "data": [], - "valid": false - }, - { - "description": "ignores non-arrays", - "data": "", - "valid": true - } - ] - }, - { - "description": "minItems validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minItems": 1.0 - }, - "tests": [ - { - "description": "longer is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "too short is invalid", - "data": [], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minLength.json b/jsonschema/testdata/draft2020-12/minLength.json deleted file mode 100644 index 5076c5a9..00000000 --- a/jsonschema/testdata/draft2020-12/minLength.json +++ /dev/null @@ -1,55 +0,0 @@ -[ - { - "description": "minLength validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minLength": 2 - }, - "tests": [ - { - "description": "longer is valid", - "data": "foo", - "valid": true - }, - { - "description": "exact length is valid", - "data": "fo", - "valid": true - }, - { - "description": "too short is invalid", - "data": "f", - "valid": false - }, - { - "description": "ignores non-strings", - "data": 1, - "valid": true - }, - { - "description": "one grapheme is not long enough", - "data": "\uD83D\uDCA9", - "valid": false - } - ] - }, - { - "description": "minLength validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minLength": 2.0 - }, - "tests": [ - { - "description": "longer is valid", - "data": "foo", - "valid": true - }, - { - "description": "too short is invalid", - "data": "f", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minProperties.json b/jsonschema/testdata/draft2020-12/minProperties.json deleted file mode 100644 index a753ad35..00000000 --- a/jsonschema/testdata/draft2020-12/minProperties.json +++ /dev/null @@ -1,60 +0,0 @@ -[ - { - "description": "minProperties validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minProperties": 1 - }, - "tests": [ - { - "description": "longer is valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "exact length is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "too short is invalid", - "data": {}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores strings", - "data": "", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "minProperties validation with a decimal", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minProperties": 1.0 - }, - "tests": [ - { - "description": "longer is valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "too short is invalid", - "data": {}, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/minimum.json b/jsonschema/testdata/draft2020-12/minimum.json deleted file mode 100644 index dc440527..00000000 --- a/jsonschema/testdata/draft2020-12/minimum.json +++ /dev/null @@ -1,75 +0,0 @@ -[ - { - "description": "minimum validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minimum": 1.1 - }, - "tests": [ - { - "description": "above the minimum is valid", - "data": 2.6, - "valid": true - }, - { - "description": "boundary point is valid", - "data": 1.1, - "valid": true - }, - { - "description": "below the minimum is invalid", - "data": 0.6, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - }, - { - "description": "minimum validation with signed integer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "minimum": -2 - }, - "tests": [ - { - "description": "negative above the minimum is valid", - "data": -1, - "valid": true - }, - { - "description": "positive above the minimum is valid", - "data": 0, - "valid": true - }, - { - "description": "boundary point is valid", - "data": -2, - "valid": true - }, - { - "description": "boundary point with float is valid", - "data": -2.0, - "valid": true - }, - { - "description": "float below the minimum is invalid", - "data": -2.0001, - "valid": false - }, - { - "description": "int below the minimum is invalid", - "data": -3, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "x", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/multipleOf.json b/jsonschema/testdata/draft2020-12/multipleOf.json deleted file mode 100644 index 92d6979b..00000000 --- a/jsonschema/testdata/draft2020-12/multipleOf.json +++ /dev/null @@ -1,97 +0,0 @@ -[ - { - "description": "by int", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "multipleOf": 2 - }, - "tests": [ - { - "description": "int by int", - "data": 10, - "valid": true - }, - { - "description": "int by int fail", - "data": 7, - "valid": false - }, - { - "description": "ignores non-numbers", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "by number", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "multipleOf": 1.5 - }, - "tests": [ - { - "description": "zero is multiple of anything", - "data": 0, - "valid": true - }, - { - "description": "4.5 is multiple of 1.5", - "data": 4.5, - "valid": true - }, - { - "description": "35 is not multiple of 1.5", - "data": 35, - "valid": false - } - ] - }, - { - "description": "by small number", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "multipleOf": 0.0001 - }, - "tests": [ - { - "description": "0.0075 is multiple of 0.0001", - "data": 0.0075, - "valid": true - }, - { - "description": "0.00751 is not multiple of 0.0001", - "data": 0.00751, - "valid": false - } - ] - }, - { - "description": "float division = inf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer", "multipleOf": 0.123456789 - }, - "tests": [ - { - "description": "always invalid, but naive implementations may raise an overflow error", - "data": 1e308, - "valid": false - } - ] - }, - { - "description": "small multiple of large integer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer", "multipleOf": 1e-8 - }, - "tests": [ - { - "description": "any integer is a multiple of 1e-8", - "data": 12391239123, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/not.json b/jsonschema/testdata/draft2020-12/not.json deleted file mode 100644 index 346d4a7e..00000000 --- a/jsonschema/testdata/draft2020-12/not.json +++ /dev/null @@ -1,301 +0,0 @@ -[ - { - "description": "not", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": {"type": "integer"} - }, - "tests": [ - { - "description": "allowed", - "data": "foo", - "valid": true - }, - { - "description": "disallowed", - "data": 1, - "valid": false - } - ] - }, - { - "description": "not multiple types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": {"type": ["integer", "boolean"]} - }, - "tests": [ - { - "description": "valid", - "data": "foo", - "valid": true - }, - { - "description": "mismatch", - "data": 1, - "valid": false - }, - { - "description": "other mismatch", - "data": true, - "valid": false - } - ] - }, - { - "description": "not more complex schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": { - "type": "object", - "properties": { - "foo": { - "type": "string" - } - } - } - }, - "tests": [ - { - "description": "match", - "data": 1, - "valid": true - }, - { - "description": "other match", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "mismatch", - "data": {"foo": "bar"}, - "valid": false - } - ] - }, - { - "description": "forbidden property", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": { - "not": {} - } - } - }, - "tests": [ - { - "description": "property present", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "property absent", - "data": {"bar": 1, "baz": 2}, - "valid": true - } - ] - }, - { - "description": "forbid everything with empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": {} - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "boolean true is invalid", - "data": true, - "valid": false - }, - { - "description": "boolean false is invalid", - "data": false, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - }, - { - "description": "object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "forbid everything with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": true - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "boolean true is invalid", - "data": true, - "valid": false - }, - { - "description": "boolean false is invalid", - "data": false, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - }, - { - "description": "object is invalid", - "data": {"foo": "bar"}, - "valid": false - }, - { - "description": "empty object is invalid", - "data": {}, - "valid": false - }, - { - "description": "array is invalid", - "data": ["foo"], - "valid": false - }, - { - "description": "empty array is invalid", - "data": [], - "valid": false - } - ] - }, - { - "description": "allow everything with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": false - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "boolean true is valid", - "data": true, - "valid": true - }, - { - "description": "boolean false is valid", - "data": false, - "valid": true - }, - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - }, - { - "description": "array is valid", - "data": ["foo"], - "valid": true - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "double negation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": { "not": {} } - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "collect annotations inside a 'not', even if collection is disabled", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "not": { - "$comment": "this subschema must still produce annotations internally, even though the 'not' will ultimately discard them", - "anyOf": [ - true, - { "properties": { "foo": true } } - ], - "unevaluatedProperties": false - } - }, - "tests": [ - { - "description": "unevaluated property", - "data": { "bar": 1 }, - "valid": true - }, - { - "description": "annotations are still collected inside a 'not'", - "data": { "foo": 1 }, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/oneOf.json b/jsonschema/testdata/draft2020-12/oneOf.json deleted file mode 100644 index 7a7c7ffe..00000000 --- a/jsonschema/testdata/draft2020-12/oneOf.json +++ /dev/null @@ -1,293 +0,0 @@ -[ - { - "description": "oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "type": "integer" - }, - { - "minimum": 2 - } - ] - }, - "tests": [ - { - "description": "first oneOf valid", - "data": 1, - "valid": true - }, - { - "description": "second oneOf valid", - "data": 2.5, - "valid": true - }, - { - "description": "both oneOf valid", - "data": 3, - "valid": false - }, - { - "description": "neither oneOf valid", - "data": 1.5, - "valid": false - } - ] - }, - { - "description": "oneOf with base schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string", - "oneOf" : [ - { - "minLength": 2 - }, - { - "maxLength": 4 - } - ] - }, - "tests": [ - { - "description": "mismatch base schema", - "data": 3, - "valid": false - }, - { - "description": "one oneOf valid", - "data": "foobar", - "valid": true - }, - { - "description": "both oneOf valid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf with boolean schemas, all true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [true, true, true] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf with boolean schemas, one true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [true, false, false] - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "oneOf with boolean schemas, more than one true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [true, true, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf with boolean schemas, all false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [false, false, false] - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "oneOf complex types", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "properties": { - "bar": {"type": "integer"} - }, - "required": ["bar"] - }, - { - "properties": { - "foo": {"type": "string"} - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "first oneOf valid (complex)", - "data": {"bar": 2}, - "valid": true - }, - { - "description": "second oneOf valid (complex)", - "data": {"foo": "baz"}, - "valid": true - }, - { - "description": "both oneOf valid (complex)", - "data": {"foo": "baz", "bar": 2}, - "valid": false - }, - { - "description": "neither oneOf valid (complex)", - "data": {"foo": 2, "bar": "quux"}, - "valid": false - } - ] - }, - { - "description": "oneOf with empty schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { "type": "number" }, - {} - ] - }, - "tests": [ - { - "description": "one valid - valid", - "data": "foo", - "valid": true - }, - { - "description": "both valid - invalid", - "data": 123, - "valid": false - } - ] - }, - { - "description": "oneOf with required", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "oneOf": [ - { "required": ["foo", "bar"] }, - { "required": ["foo", "baz"] } - ] - }, - "tests": [ - { - "description": "both invalid - invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "first valid - valid", - "data": {"foo": 1, "bar": 2}, - "valid": true - }, - { - "description": "second valid - valid", - "data": {"foo": 1, "baz": 3}, - "valid": true - }, - { - "description": "both valid - invalid", - "data": {"foo": 1, "bar": 2, "baz" : 3}, - "valid": false - } - ] - }, - { - "description": "oneOf with missing optional property", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "properties": { - "bar": true, - "baz": true - }, - "required": ["bar"] - }, - { - "properties": { - "foo": true - }, - "required": ["foo"] - } - ] - }, - "tests": [ - { - "description": "first oneOf valid", - "data": {"bar": 8}, - "valid": true - }, - { - "description": "second oneOf valid", - "data": {"foo": "foo"}, - "valid": true - }, - { - "description": "both oneOf valid", - "data": {"foo": "foo", "bar": 8}, - "valid": false - }, - { - "description": "neither oneOf valid", - "data": {"baz": "quux"}, - "valid": false - } - ] - }, - { - "description": "nested oneOf, to check validation semantics", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "oneOf": [ - { - "oneOf": [ - { - "type": "null" - } - ] - } - ] - }, - "tests": [ - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "anything non-null is invalid", - "data": 123, - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/pattern.json b/jsonschema/testdata/draft2020-12/pattern.json deleted file mode 100644 index af0b8d89..00000000 --- a/jsonschema/testdata/draft2020-12/pattern.json +++ /dev/null @@ -1,65 +0,0 @@ -[ - { - "description": "pattern validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "pattern": "^a*$" - }, - "tests": [ - { - "description": "a matching pattern is valid", - "data": "aaa", - "valid": true - }, - { - "description": "a non-matching pattern is invalid", - "data": "abc", - "valid": false - }, - { - "description": "ignores booleans", - "data": true, - "valid": true - }, - { - "description": "ignores integers", - "data": 123, - "valid": true - }, - { - "description": "ignores floats", - "data": 1.0, - "valid": true - }, - { - "description": "ignores objects", - "data": {}, - "valid": true - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores null", - "data": null, - "valid": true - } - ] - }, - { - "description": "pattern is not anchored", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "pattern": "a+" - }, - "tests": [ - { - "description": "matches a substring", - "data": "xxaayy", - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/patternProperties.json b/jsonschema/testdata/draft2020-12/patternProperties.json deleted file mode 100644 index 81829c71..00000000 --- a/jsonschema/testdata/draft2020-12/patternProperties.json +++ /dev/null @@ -1,176 +0,0 @@ -[ - { - "description": - "patternProperties validates properties matching a regex", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "f.*o": {"type": "integer"} - } - }, - "tests": [ - { - "description": "a single valid match is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "multiple valid matches is valid", - "data": {"foo": 1, "foooooo" : 2}, - "valid": true - }, - { - "description": "a single invalid match is invalid", - "data": {"foo": "bar", "fooooo": 2}, - "valid": false - }, - { - "description": "multiple invalid matches is invalid", - "data": {"foo": "bar", "foooooo" : "baz"}, - "valid": false - }, - { - "description": "ignores arrays", - "data": ["foo"], - "valid": true - }, - { - "description": "ignores strings", - "data": "foo", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "multiple simultaneous patternProperties are validated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "a*": {"type": "integer"}, - "aaa*": {"maximum": 20} - } - }, - "tests": [ - { - "description": "a single valid match is valid", - "data": {"a": 21}, - "valid": true - }, - { - "description": "a simultaneous match is valid", - "data": {"aaaa": 18}, - "valid": true - }, - { - "description": "multiple matches is valid", - "data": {"a": 21, "aaaa": 18}, - "valid": true - }, - { - "description": "an invalid due to one is invalid", - "data": {"a": "bar"}, - "valid": false - }, - { - "description": "an invalid due to the other is invalid", - "data": {"aaaa": 31}, - "valid": false - }, - { - "description": "an invalid due to both is invalid", - "data": {"aaa": "foo", "aaaa": 31}, - "valid": false - } - ] - }, - { - "description": "regexes are not anchored by default and are case sensitive", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "[0-9]{2,}": { "type": "boolean" }, - "X_": { "type": "string" } - } - }, - "tests": [ - { - "description": "non recognized members are ignored", - "data": { "answer 1": "42" }, - "valid": true - }, - { - "description": "recognized members are accounted for", - "data": { "a31b": null }, - "valid": false - }, - { - "description": "regexes are case sensitive", - "data": { "a_x_3": 3 }, - "valid": true - }, - { - "description": "regexes are case sensitive, 2", - "data": { "a_X_3": 3 }, - "valid": false - } - ] - }, - { - "description": "patternProperties with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "f.*": true, - "b.*": false - } - }, - "tests": [ - { - "description": "object with property matching schema true is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with property matching schema false is invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "object with both properties is invalid", - "data": {"foo": 1, "bar": 2}, - "valid": false - }, - { - "description": "object with a property matching both true and false is invalid", - "data": {"foobar":1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "patternProperties with null valued instance properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "patternProperties": { - "^.*bar$": {"type": "null"} - } - }, - "tests": [ - { - "description": "allows null values", - "data": {"foobar": null}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/prefixItems.json b/jsonschema/testdata/draft2020-12/prefixItems.json deleted file mode 100644 index 0adfc069..00000000 --- a/jsonschema/testdata/draft2020-12/prefixItems.json +++ /dev/null @@ -1,104 +0,0 @@ -[ - { - "description": "a schema given for prefixItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - {"type": "integer"}, - {"type": "string"} - ] - }, - "tests": [ - { - "description": "correct types", - "data": [ 1, "foo" ], - "valid": true - }, - { - "description": "wrong types", - "data": [ "foo", 1 ], - "valid": false - }, - { - "description": "incomplete array of items", - "data": [ 1 ], - "valid": true - }, - { - "description": "array with additional items", - "data": [ 1, "foo", true ], - "valid": true - }, - { - "description": "empty array", - "data": [ ], - "valid": true - }, - { - "description": "JavaScript pseudo-array is valid", - "data": { - "0": "invalid", - "1": "valid", - "length": 2 - }, - "valid": true - } - ] - }, - { - "description": "prefixItems with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [true, false] - }, - "tests": [ - { - "description": "array with one item is valid", - "data": [ 1 ], - "valid": true - }, - { - "description": "array with two items is invalid", - "data": [ 1, "foo" ], - "valid": false - }, - { - "description": "empty array is valid", - "data": [], - "valid": true - } - ] - }, - { - "description": "additional items are allowed by default", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "integer"}] - }, - "tests": [ - { - "description": "only the first item is validated", - "data": [1, "foo", false], - "valid": true - } - ] - }, - { - "description": "prefixItems with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { - "type": "null" - } - ] - }, - "tests": [ - { - "description": "allows null elements", - "data": [ null ], - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/properties.json b/jsonschema/testdata/draft2020-12/properties.json deleted file mode 100644 index 523dcde7..00000000 --- a/jsonschema/testdata/draft2020-12/properties.json +++ /dev/null @@ -1,242 +0,0 @@ -[ - { - "description": "object properties validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "integer"}, - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "both properties present and valid is valid", - "data": {"foo": 1, "bar": "baz"}, - "valid": true - }, - { - "description": "one property invalid is invalid", - "data": {"foo": 1, "bar": {}}, - "valid": false - }, - { - "description": "both properties invalid is invalid", - "data": {"foo": [], "bar": {}}, - "valid": false - }, - { - "description": "doesn't invalidate other properties", - "data": {"quux": []}, - "valid": true - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": - "properties, patternProperties, additionalProperties interaction", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "array", "maxItems": 3}, - "bar": {"type": "array"} - }, - "patternProperties": {"f.o": {"minItems": 2}}, - "additionalProperties": {"type": "integer"} - }, - "tests": [ - { - "description": "property validates property", - "data": {"foo": [1, 2]}, - "valid": true - }, - { - "description": "property invalidates property", - "data": {"foo": [1, 2, 3, 4]}, - "valid": false - }, - { - "description": "patternProperty invalidates property", - "data": {"foo": []}, - "valid": false - }, - { - "description": "patternProperty validates nonproperty", - "data": {"fxo": [1, 2]}, - "valid": true - }, - { - "description": "patternProperty invalidates nonproperty", - "data": {"fxo": []}, - "valid": false - }, - { - "description": "additionalProperty ignores property", - "data": {"bar": []}, - "valid": true - }, - { - "description": "additionalProperty validates others", - "data": {"quux": 3}, - "valid": true - }, - { - "description": "additionalProperty invalidates others", - "data": {"quux": "foo"}, - "valid": false - } - ] - }, - { - "description": "properties with boolean schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": true, - "bar": false - } - }, - "tests": [ - { - "description": "no property present is valid", - "data": {}, - "valid": true - }, - { - "description": "only 'true' property present is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "only 'false' property present is invalid", - "data": {"bar": 2}, - "valid": false - }, - { - "description": "both properties present is invalid", - "data": {"foo": 1, "bar": 2}, - "valid": false - } - ] - }, - { - "description": "properties with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo\nbar": {"type": "number"}, - "foo\"bar": {"type": "number"}, - "foo\\bar": {"type": "number"}, - "foo\rbar": {"type": "number"}, - "foo\tbar": {"type": "number"}, - "foo\fbar": {"type": "number"} - } - }, - "tests": [ - { - "description": "object with all numbers is valid", - "data": { - "foo\nbar": 1, - "foo\"bar": 1, - "foo\\bar": 1, - "foo\rbar": 1, - "foo\tbar": 1, - "foo\fbar": 1 - }, - "valid": true - }, - { - "description": "object with strings is invalid", - "data": { - "foo\nbar": "1", - "foo\"bar": "1", - "foo\\bar": "1", - "foo\rbar": "1", - "foo\tbar": "1", - "foo\fbar": "1" - }, - "valid": false - } - ] - }, - { - "description": "properties with null valued instance properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "null"} - } - }, - "tests": [ - { - "description": "allows null values", - "data": {"foo": null}, - "valid": true - } - ] - }, - { - "description": "properties whose names are Javascript object property names", - "comment": "Ensure JS implementations don't universally consider e.g. __proto__ to always be present in an object.", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "__proto__": {"type": "number"}, - "toString": { - "properties": { "length": { "type": "string" } } - }, - "constructor": {"type": "number"} - } - }, - "tests": [ - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - }, - { - "description": "none of the properties mentioned", - "data": {}, - "valid": true - }, - { - "description": "__proto__ not valid", - "data": { "__proto__": "foo" }, - "valid": false - }, - { - "description": "toString not valid", - "data": { "toString": { "length": 37 } }, - "valid": false - }, - { - "description": "constructor not valid", - "data": { "constructor": { "length": 37 } }, - "valid": false - }, - { - "description": "all present and valid", - "data": { - "__proto__": 12, - "toString": { "length": "foo" }, - "constructor": 37 - }, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/propertyNames.json b/jsonschema/testdata/draft2020-12/propertyNames.json deleted file mode 100644 index b4780088..00000000 --- a/jsonschema/testdata/draft2020-12/propertyNames.json +++ /dev/null @@ -1,168 +0,0 @@ -[ - { - "description": "propertyNames validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"maxLength": 3} - }, - "tests": [ - { - "description": "all property names valid", - "data": { - "f": {}, - "foo": {} - }, - "valid": true - }, - { - "description": "some property names invalid", - "data": { - "foo": {}, - "foobar": {} - }, - "valid": false - }, - { - "description": "object without properties is valid", - "data": {}, - "valid": true - }, - { - "description": "ignores arrays", - "data": [1, 2, 3, 4], - "valid": true - }, - { - "description": "ignores strings", - "data": "foobar", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "propertyNames validation with pattern", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": { "pattern": "^a+$" } - }, - "tests": [ - { - "description": "matching property names valid", - "data": { - "a": {}, - "aa": {}, - "aaa": {} - }, - "valid": true - }, - { - "description": "non-matching property name is invalid", - "data": { - "aaA": {} - }, - "valid": false - }, - { - "description": "object without properties is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": true - }, - "tests": [ - { - "description": "object with any properties is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": false - }, - "tests": [ - { - "description": "object with any properties is invalid", - "data": {"foo": 1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with const", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"const": "foo"} - }, - "tests": [ - { - "description": "object with property foo is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with any other property is invalid", - "data": {"bar": 1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - }, - { - "description": "propertyNames with enum", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"enum": ["foo", "bar"]} - }, - "tests": [ - { - "description": "object with property foo is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "object with property foo and bar is valid", - "data": {"foo": 1, "bar": 1}, - "valid": true - }, - { - "description": "object with any other property is invalid", - "data": {"baz": 1}, - "valid": false - }, - { - "description": "empty object is valid", - "data": {}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/ref.json b/jsonschema/testdata/draft2020-12/ref.json deleted file mode 100644 index 0ac02fb9..00000000 --- a/jsonschema/testdata/draft2020-12/ref.json +++ /dev/null @@ -1,1052 +0,0 @@ -[ - { - "description": "root pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"$ref": "#"} - }, - "additionalProperties": false - }, - "tests": [ - { - "description": "match", - "data": {"foo": false}, - "valid": true - }, - { - "description": "recursive match", - "data": {"foo": {"foo": false}}, - "valid": true - }, - { - "description": "mismatch", - "data": {"bar": false}, - "valid": false - }, - { - "description": "recursive mismatch", - "data": {"foo": {"bar": false}}, - "valid": false - } - ] - }, - { - "description": "relative pointer ref to object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {"type": "integer"}, - "bar": {"$ref": "#/properties/foo"} - } - }, - "tests": [ - { - "description": "match", - "data": {"bar": 3}, - "valid": true - }, - { - "description": "mismatch", - "data": {"bar": true}, - "valid": false - } - ] - }, - { - "description": "relative pointer ref to array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - {"type": "integer"}, - {"$ref": "#/prefixItems/0"} - ] - }, - "tests": [ - { - "description": "match array", - "data": [1, 2], - "valid": true - }, - { - "description": "mismatch array", - "data": [1, "foo"], - "valid": false - } - ] - }, - { - "description": "escaped pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "tilde~field": {"type": "integer"}, - "slash/field": {"type": "integer"}, - "percent%field": {"type": "integer"} - }, - "properties": { - "tilde": {"$ref": "#/$defs/tilde~0field"}, - "slash": {"$ref": "#/$defs/slash~1field"}, - "percent": {"$ref": "#/$defs/percent%25field"} - } - }, - "tests": [ - { - "description": "slash invalid", - "data": {"slash": "aoeu"}, - "valid": false - }, - { - "description": "tilde invalid", - "data": {"tilde": "aoeu"}, - "valid": false - }, - { - "description": "percent invalid", - "data": {"percent": "aoeu"}, - "valid": false - }, - { - "description": "slash valid", - "data": {"slash": 123}, - "valid": true - }, - { - "description": "tilde valid", - "data": {"tilde": 123}, - "valid": true - }, - { - "description": "percent valid", - "data": {"percent": 123}, - "valid": true - } - ] - }, - { - "description": "nested refs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "a": {"type": "integer"}, - "b": {"$ref": "#/$defs/a"}, - "c": {"$ref": "#/$defs/b"} - }, - "$ref": "#/$defs/c" - }, - "tests": [ - { - "description": "nested ref valid", - "data": 5, - "valid": true - }, - { - "description": "nested ref invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "ref applies alongside sibling keywords", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "reffed": { - "type": "array" - } - }, - "properties": { - "foo": { - "$ref": "#/$defs/reffed", - "maxItems": 2 - } - } - }, - "tests": [ - { - "description": "ref valid, maxItems valid", - "data": { "foo": [] }, - "valid": true - }, - { - "description": "ref valid, maxItems invalid", - "data": { "foo": [1, 2, 3] }, - "valid": false - }, - { - "description": "ref invalid", - "data": { "foo": "string" }, - "valid": false - } - ] - }, - { - "description": "remote ref, containing refs itself", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "https://json-schema.org/draft/2020-12/schema" - }, - "tests": [ - { - "description": "remote ref valid", - "data": {"minLength": 1}, - "valid": true - }, - { - "description": "remote ref invalid", - "data": {"minLength": -1}, - "valid": false - } - ] - }, - { - "description": "property named $ref that is not a reference", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "$ref": {"type": "string"} - } - }, - "tests": [ - { - "description": "property named $ref valid", - "data": {"$ref": "a"}, - "valid": true - }, - { - "description": "property named $ref invalid", - "data": {"$ref": 2}, - "valid": false - } - ] - }, - { - "description": "property named $ref, containing an actual $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "$ref": {"$ref": "#/$defs/is-string"} - }, - "$defs": { - "is-string": { - "type": "string" - } - } - }, - "tests": [ - { - "description": "property named $ref valid", - "data": {"$ref": "a"}, - "valid": true - }, - { - "description": "property named $ref invalid", - "data": {"$ref": 2}, - "valid": false - } - ] - }, - { - "description": "$ref to boolean schema true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/bool", - "$defs": { - "bool": true - } - }, - "tests": [ - { - "description": "any value is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "$ref to boolean schema false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/bool", - "$defs": { - "bool": false - } - }, - "tests": [ - { - "description": "any value is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "Recursive references between schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/tree", - "description": "tree of nodes", - "type": "object", - "properties": { - "meta": {"type": "string"}, - "nodes": { - "type": "array", - "items": {"$ref": "node"} - } - }, - "required": ["meta", "nodes"], - "$defs": { - "node": { - "$id": "http://localhost:1234/draft2020-12/node", - "description": "node", - "type": "object", - "properties": { - "value": {"type": "number"}, - "subtree": {"$ref": "tree"} - }, - "required": ["value"] - } - } - }, - "tests": [ - { - "description": "valid tree", - "data": { - "meta": "root", - "nodes": [ - { - "value": 1, - "subtree": { - "meta": "child", - "nodes": [ - {"value": 1.1}, - {"value": 1.2} - ] - } - }, - { - "value": 2, - "subtree": { - "meta": "child", - "nodes": [ - {"value": 2.1}, - {"value": 2.2} - ] - } - } - ] - }, - "valid": true - }, - { - "description": "invalid tree", - "data": { - "meta": "root", - "nodes": [ - { - "value": 1, - "subtree": { - "meta": "child", - "nodes": [ - {"value": "string is invalid"}, - {"value": 1.2} - ] - } - }, - { - "value": 2, - "subtree": { - "meta": "child", - "nodes": [ - {"value": 2.1}, - {"value": 2.2} - ] - } - } - ] - }, - "valid": false - } - ] - }, - { - "description": "refs with quote", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo\"bar": {"$ref": "#/$defs/foo%22bar"} - }, - "$defs": { - "foo\"bar": {"type": "number"} - } - }, - "tests": [ - { - "description": "object with numbers is valid", - "data": { - "foo\"bar": 1 - }, - "valid": true - }, - { - "description": "object with strings is invalid", - "data": { - "foo\"bar": "1" - }, - "valid": false - } - ] - }, - { - "description": "ref creates new scope when adjacent to keywords", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "A": { - "unevaluatedProperties": false - } - }, - "properties": { - "prop1": { - "type": "string" - } - }, - "$ref": "#/$defs/A" - }, - "tests": [ - { - "description": "referenced subschema doesn't see annotations from properties", - "data": { - "prop1": "match" - }, - "valid": false - } - ] - }, - { - "description": "naive replacement of $ref with its destination is not correct", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "a_string": { "type": "string" } - }, - "enum": [ - { "$ref": "#/$defs/a_string" } - ] - }, - "tests": [ - { - "description": "do not evaluate the $ref inside the enum, matching any string", - "data": "this is a string", - "valid": false - }, - { - "description": "do not evaluate the $ref inside the enum, definition exact match", - "data": { "type": "string" }, - "valid": false - }, - { - "description": "match the enum exactly", - "data": { "$ref": "#/$defs/a_string" }, - "valid": true - } - ] - }, - { - "description": "refs with relative uris and defs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/schema-relative-uri-defs1.json", - "properties": { - "foo": { - "$id": "schema-relative-uri-defs2.json", - "$defs": { - "inner": { - "properties": { - "bar": { "type": "string" } - } - } - }, - "$ref": "#/$defs/inner" - } - }, - "$ref": "schema-relative-uri-defs2.json" - }, - "tests": [ - { - "description": "invalid on inner field", - "data": { - "foo": { - "bar": 1 - }, - "bar": "a" - }, - "valid": false - }, - { - "description": "invalid on outer field", - "data": { - "foo": { - "bar": "a" - }, - "bar": 1 - }, - "valid": false - }, - { - "description": "valid on both fields", - "data": { - "foo": { - "bar": "a" - }, - "bar": "a" - }, - "valid": true - } - ] - }, - { - "description": "relative refs with absolute uris and defs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/schema-refs-absolute-uris-defs1.json", - "properties": { - "foo": { - "$id": "http://example.com/schema-refs-absolute-uris-defs2.json", - "$defs": { - "inner": { - "properties": { - "bar": { "type": "string" } - } - } - }, - "$ref": "#/$defs/inner" - } - }, - "$ref": "schema-refs-absolute-uris-defs2.json" - }, - "tests": [ - { - "description": "invalid on inner field", - "data": { - "foo": { - "bar": 1 - }, - "bar": "a" - }, - "valid": false - }, - { - "description": "invalid on outer field", - "data": { - "foo": { - "bar": "a" - }, - "bar": 1 - }, - "valid": false - }, - { - "description": "valid on both fields", - "data": { - "foo": { - "bar": "a" - }, - "bar": "a" - }, - "valid": true - } - ] - }, - { - "description": "$id must be resolved against nearest parent, not just immediate parent", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/a.json", - "$defs": { - "x": { - "$id": "http://example.com/b/c.json", - "not": { - "$defs": { - "y": { - "$id": "d.json", - "type": "number" - } - } - } - } - }, - "allOf": [ - { - "$ref": "http://example.com/b/d.json" - } - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "order of evaluation: $id and $ref", - "schema": { - "$comment": "$id must be evaluated before $ref to get the proper $ref destination", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/draft2020-12/ref-and-id1/base.json", - "$ref": "int.json", - "$defs": { - "bigint": { - "$comment": "canonical uri: https://example.com/ref-and-id1/int.json", - "$id": "int.json", - "maximum": 10 - }, - "smallint": { - "$comment": "canonical uri: https://example.com/ref-and-id1-int.json", - "$id": "/draft2020-12/ref-and-id1-int.json", - "maximum": 2 - } - } - }, - "tests": [ - { - "description": "data is valid against first definition", - "data": 5, - "valid": true - }, - { - "description": "data is invalid against first definition", - "data": 50, - "valid": false - } - ] - }, - { - "description": "order of evaluation: $id and $anchor and $ref", - "schema": { - "$comment": "$id must be evaluated before $ref to get the proper $ref destination", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/draft2020-12/ref-and-id2/base.json", - "$ref": "#bigint", - "$defs": { - "bigint": { - "$comment": "canonical uri: /ref-and-id2/base.json#/$defs/bigint; another valid uri for this location: /ref-and-id2/base.json#bigint", - "$anchor": "bigint", - "maximum": 10 - }, - "smallint": { - "$comment": "canonical uri: https://example.com/ref-and-id2#/$defs/smallint; another valid uri for this location: https://example.com/ref-and-id2/#bigint", - "$id": "https://example.com/draft2020-12/ref-and-id2/", - "$anchor": "bigint", - "maximum": 2 - } - } - }, - "tests": [ - { - "description": "data is valid against first definition", - "data": 5, - "valid": true - }, - { - "description": "data is invalid against first definition", - "data": 50, - "valid": false - } - ] - }, - { - "description": "simple URN base URI with $ref via the URN", - "schema": { - "$comment": "URIs do not have to have HTTP(s) schemes", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-ffff-ffff-4321feebdaed", - "minimum": 30, - "properties": { - "foo": {"$ref": "urn:uuid:deadbeef-1234-ffff-ffff-4321feebdaed"} - } - }, - "tests": [ - { - "description": "valid under the URN IDed schema", - "data": {"foo": 37}, - "valid": true - }, - { - "description": "invalid under the URN IDed schema", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "simple URN base URI with JSON pointer", - "schema": { - "$comment": "URIs do not have to have HTTP(s) schemes", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-00ff-ff00-4321feebdaed", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with NSS", - "schema": { - "$comment": "RFC 8141 §2.2", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:example:1/406/47452/2", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with r-component", - "schema": { - "$comment": "RFC 8141 §2.3.1", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:example:foo-bar-baz-qux?+CCResolve:cc=uk", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with q-component", - "schema": { - "$comment": "RFC 8141 §2.3.2", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:example:weather?=op=map&lat=39.56&lon=-104.85&datetime=1969-07-21T02:56:15Z", - "properties": { - "foo": {"$ref": "#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with URN and JSON pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-0000-0000-4321feebdaed", - "properties": { - "foo": {"$ref": "urn:uuid:deadbeef-1234-0000-0000-4321feebdaed#/$defs/bar"} - }, - "$defs": { - "bar": {"type": "string"} - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN base URI with URN and anchor ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "urn:uuid:deadbeef-1234-ff00-00ff-4321feebdaed", - "properties": { - "foo": {"$ref": "urn:uuid:deadbeef-1234-ff00-00ff-4321feebdaed#something"} - }, - "$defs": { - "bar": { - "$anchor": "something", - "type": "string" - } - } - }, - "tests": [ - { - "description": "a string is valid", - "data": {"foo": "bar"}, - "valid": true - }, - { - "description": "a non-string is invalid", - "data": {"foo": 12}, - "valid": false - } - ] - }, - { - "description": "URN ref with nested pointer ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "urn:uuid:deadbeef-4321-ffff-ffff-1234feebdaed", - "$defs": { - "foo": { - "$id": "urn:uuid:deadbeef-4321-ffff-ffff-1234feebdaed", - "$defs": {"bar": {"type": "string"}}, - "$ref": "#/$defs/bar" - } - } - }, - "tests": [ - { - "description": "a string is valid", - "data": "bar", - "valid": true - }, - { - "description": "a non-string is invalid", - "data": 12, - "valid": false - } - ] - }, - { - "description": "ref to if", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://example.com/ref/if", - "if": { - "$id": "http://example.com/ref/if", - "type": "integer" - } - }, - "tests": [ - { - "description": "a non-integer is invalid due to the $ref", - "data": "foo", - "valid": false - }, - { - "description": "an integer is valid", - "data": 12, - "valid": true - } - ] - }, - { - "description": "ref to then", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://example.com/ref/then", - "then": { - "$id": "http://example.com/ref/then", - "type": "integer" - } - }, - "tests": [ - { - "description": "a non-integer is invalid due to the $ref", - "data": "foo", - "valid": false - }, - { - "description": "an integer is valid", - "data": 12, - "valid": true - } - ] - }, - { - "description": "ref to else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://example.com/ref/else", - "else": { - "$id": "http://example.com/ref/else", - "type": "integer" - } - }, - "tests": [ - { - "description": "a non-integer is invalid due to the $ref", - "data": "foo", - "valid": false - }, - { - "description": "an integer is valid", - "data": 12, - "valid": true - } - ] - }, - { - "description": "ref with absolute-path-reference", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://example.com/ref/absref.json", - "$defs": { - "a": { - "$id": "http://example.com/ref/absref/foobar.json", - "type": "number" - }, - "b": { - "$id": "http://example.com/absref/foobar.json", - "type": "string" - } - }, - "$ref": "/absref/foobar.json" - }, - "tests": [ - { - "description": "a string is valid", - "data": "foo", - "valid": true - }, - { - "description": "an integer is invalid", - "data": 12, - "valid": false - } - ] - }, - { - "description": "$id with file URI still resolves pointers - *nix", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "file:///folder/file.json", - "$defs": { - "foo": { - "type": "number" - } - }, - "$ref": "#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "$id with file URI still resolves pointers - windows", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "file:///c:/folder/file.json", - "$defs": { - "foo": { - "type": "number" - } - }, - "$ref": "#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "empty tokens in $ref json-pointer", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "": { - "$defs": { - "": { "type": "number" } - } - } - }, - "allOf": [ - { - "$ref": "#/$defs//$defs/" - } - ] - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/refRemote.json b/jsonschema/testdata/draft2020-12/refRemote.json deleted file mode 100644 index 047ac74c..00000000 --- a/jsonschema/testdata/draft2020-12/refRemote.json +++ /dev/null @@ -1,342 +0,0 @@ -[ - { - "description": "remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/integer.json" - }, - "tests": [ - { - "description": "remote ref valid", - "data": 1, - "valid": true - }, - { - "description": "remote ref invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "fragment within remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/subSchemas.json#/$defs/integer" - }, - "tests": [ - { - "description": "remote fragment valid", - "data": 1, - "valid": true - }, - { - "description": "remote fragment invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "anchor within remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/locationIndependentIdentifier.json#foo" - }, - "tests": [ - { - "description": "remote anchor valid", - "data": 1, - "valid": true - }, - { - "description": "remote anchor invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "ref within remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/subSchemas.json#/$defs/refToInteger" - }, - "tests": [ - { - "description": "ref within ref valid", - "data": 1, - "valid": true - }, - { - "description": "ref within ref invalid", - "data": "a", - "valid": false - } - ] - }, - { - "description": "base URI change", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/", - "items": { - "$id": "baseUriChange/", - "items": {"$ref": "folderInteger.json"} - } - }, - "tests": [ - { - "description": "base URI change ref valid", - "data": [[1]], - "valid": true - }, - { - "description": "base URI change ref invalid", - "data": [["a"]], - "valid": false - } - ] - }, - { - "description": "base URI change - change folder", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/scope_change_defs1.json", - "type" : "object", - "properties": {"list": {"$ref": "baseUriChangeFolder/"}}, - "$defs": { - "baz": { - "$id": "baseUriChangeFolder/", - "type": "array", - "items": {"$ref": "folderInteger.json"} - } - } - }, - "tests": [ - { - "description": "number is valid", - "data": {"list": [1]}, - "valid": true - }, - { - "description": "string is invalid", - "data": {"list": ["a"]}, - "valid": false - } - ] - }, - { - "description": "base URI change - change folder in subschema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/scope_change_defs2.json", - "type" : "object", - "properties": {"list": {"$ref": "baseUriChangeFolderInSubschema/#/$defs/bar"}}, - "$defs": { - "baz": { - "$id": "baseUriChangeFolderInSubschema/", - "$defs": { - "bar": { - "type": "array", - "items": {"$ref": "folderInteger.json"} - } - } - } - } - }, - "tests": [ - { - "description": "number is valid", - "data": {"list": [1]}, - "valid": true - }, - { - "description": "string is invalid", - "data": {"list": ["a"]}, - "valid": false - } - ] - }, - { - "description": "root ref in remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/object", - "type": "object", - "properties": { - "name": {"$ref": "name-defs.json#/$defs/orNull"} - } - }, - "tests": [ - { - "description": "string is valid", - "data": { - "name": "foo" - }, - "valid": true - }, - { - "description": "null is valid", - "data": { - "name": null - }, - "valid": true - }, - { - "description": "object is invalid", - "data": { - "name": { - "name": null - } - }, - "valid": false - } - ] - }, - { - "description": "remote ref with ref to defs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/schema-remote-ref-ref-defs1.json", - "$ref": "ref-and-defs.json" - }, - "tests": [ - { - "description": "invalid", - "data": { - "bar": 1 - }, - "valid": false - }, - { - "description": "valid", - "data": { - "bar": "a" - }, - "valid": true - } - ] - }, - { - "description": "Location-independent identifier in remote ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/locationIndependentIdentifier.json#/$defs/refToInteger" - }, - "tests": [ - { - "description": "integer is valid", - "data": 1, - "valid": true - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - }, - { - "description": "retrieved nested refs resolve relative to their URI not $id", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/some-id", - "properties": { - "name": {"$ref": "nested/foo-ref-string.json"} - } - }, - "tests": [ - { - "description": "number is invalid", - "data": { - "name": {"foo": 1} - }, - "valid": false - }, - { - "description": "string is valid", - "data": { - "name": {"foo": "a"} - }, - "valid": true - } - ] - }, - { - "description": "remote HTTP ref with different $id", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/different-id-ref-string.json" - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "remote HTTP ref with different URN $id", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/urn-ref-string.json" - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "remote HTTP ref with nested absolute ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/nested-absolute-ref-to-string.json" - }, - "tests": [ - { - "description": "number is invalid", - "data": 1, - "valid": false - }, - { - "description": "string is valid", - "data": "foo", - "valid": true - } - ] - }, - { - "description": "$ref to $ref finds detached $anchor", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "http://localhost:1234/draft2020-12/detached-ref.json#/$defs/foo" - }, - "tests": [ - { - "description": "number is valid", - "data": 1, - "valid": true - }, - { - "description": "non-number is invalid", - "data": "a", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/required.json b/jsonschema/testdata/draft2020-12/required.json deleted file mode 100644 index e66f29f2..00000000 --- a/jsonschema/testdata/draft2020-12/required.json +++ /dev/null @@ -1,158 +0,0 @@ -[ - { - "description": "required validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {}, - "bar": {} - }, - "required": ["foo"] - }, - "tests": [ - { - "description": "present required property is valid", - "data": {"foo": 1}, - "valid": true - }, - { - "description": "non-present required property is invalid", - "data": {"bar": 1}, - "valid": false - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores strings", - "data": "", - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - } - ] - }, - { - "description": "required default validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {} - } - }, - "tests": [ - { - "description": "not required by default", - "data": {}, - "valid": true - } - ] - }, - { - "description": "required with empty array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": {} - }, - "required": [] - }, - "tests": [ - { - "description": "property not required", - "data": {}, - "valid": true - } - ] - }, - { - "description": "required with escaped characters", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "required": [ - "foo\nbar", - "foo\"bar", - "foo\\bar", - "foo\rbar", - "foo\tbar", - "foo\fbar" - ] - }, - "tests": [ - { - "description": "object with all properties present is valid", - "data": { - "foo\nbar": 1, - "foo\"bar": 1, - "foo\\bar": 1, - "foo\rbar": 1, - "foo\tbar": 1, - "foo\fbar": 1 - }, - "valid": true - }, - { - "description": "object with some properties missing is invalid", - "data": { - "foo\nbar": "1", - "foo\"bar": "1" - }, - "valid": false - } - ] - }, - { - "description": "required properties whose names are Javascript object property names", - "comment": "Ensure JS implementations don't universally consider e.g. __proto__ to always be present in an object.", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "required": ["__proto__", "toString", "constructor"] - }, - "tests": [ - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores other non-objects", - "data": 12, - "valid": true - }, - { - "description": "none of the properties mentioned", - "data": {}, - "valid": false - }, - { - "description": "__proto__ present", - "data": { "__proto__": "foo" }, - "valid": false - }, - { - "description": "toString present", - "data": { "toString": { "length": 37 } }, - "valid": false - }, - { - "description": "constructor present", - "data": { "constructor": { "length": 37 } }, - "valid": false - }, - { - "description": "all present", - "data": { - "__proto__": 12, - "toString": { "length": "foo" }, - "constructor": 37 - }, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/type.json b/jsonschema/testdata/draft2020-12/type.json deleted file mode 100644 index 2123c408..00000000 --- a/jsonschema/testdata/draft2020-12/type.json +++ /dev/null @@ -1,501 +0,0 @@ -[ - { - "description": "integer type matches integers", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" - }, - "tests": [ - { - "description": "an integer is an integer", - "data": 1, - "valid": true - }, - { - "description": "a float with zero fractional part is an integer", - "data": 1.0, - "valid": true - }, - { - "description": "a float is not an integer", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not an integer", - "data": "foo", - "valid": false - }, - { - "description": "a string is still not an integer, even if it looks like one", - "data": "1", - "valid": false - }, - { - "description": "an object is not an integer", - "data": {}, - "valid": false - }, - { - "description": "an array is not an integer", - "data": [], - "valid": false - }, - { - "description": "a boolean is not an integer", - "data": true, - "valid": false - }, - { - "description": "null is not an integer", - "data": null, - "valid": false - } - ] - }, - { - "description": "number type matches numbers", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "number" - }, - "tests": [ - { - "description": "an integer is a number", - "data": 1, - "valid": true - }, - { - "description": "a float with zero fractional part is a number (and an integer)", - "data": 1.0, - "valid": true - }, - { - "description": "a float is a number", - "data": 1.1, - "valid": true - }, - { - "description": "a string is not a number", - "data": "foo", - "valid": false - }, - { - "description": "a string is still not a number, even if it looks like one", - "data": "1", - "valid": false - }, - { - "description": "an object is not a number", - "data": {}, - "valid": false - }, - { - "description": "an array is not a number", - "data": [], - "valid": false - }, - { - "description": "a boolean is not a number", - "data": true, - "valid": false - }, - { - "description": "null is not a number", - "data": null, - "valid": false - } - ] - }, - { - "description": "string type matches strings", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string" - }, - "tests": [ - { - "description": "1 is not a string", - "data": 1, - "valid": false - }, - { - "description": "a float is not a string", - "data": 1.1, - "valid": false - }, - { - "description": "a string is a string", - "data": "foo", - "valid": true - }, - { - "description": "a string is still a string, even if it looks like a number", - "data": "1", - "valid": true - }, - { - "description": "an empty string is still a string", - "data": "", - "valid": true - }, - { - "description": "an object is not a string", - "data": {}, - "valid": false - }, - { - "description": "an array is not a string", - "data": [], - "valid": false - }, - { - "description": "a boolean is not a string", - "data": true, - "valid": false - }, - { - "description": "null is not a string", - "data": null, - "valid": false - } - ] - }, - { - "description": "object type matches objects", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object" - }, - "tests": [ - { - "description": "an integer is not an object", - "data": 1, - "valid": false - }, - { - "description": "a float is not an object", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not an object", - "data": "foo", - "valid": false - }, - { - "description": "an object is an object", - "data": {}, - "valid": true - }, - { - "description": "an array is not an object", - "data": [], - "valid": false - }, - { - "description": "a boolean is not an object", - "data": true, - "valid": false - }, - { - "description": "null is not an object", - "data": null, - "valid": false - } - ] - }, - { - "description": "array type matches arrays", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "array" - }, - "tests": [ - { - "description": "an integer is not an array", - "data": 1, - "valid": false - }, - { - "description": "a float is not an array", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not an array", - "data": "foo", - "valid": false - }, - { - "description": "an object is not an array", - "data": {}, - "valid": false - }, - { - "description": "an array is an array", - "data": [], - "valid": true - }, - { - "description": "a boolean is not an array", - "data": true, - "valid": false - }, - { - "description": "null is not an array", - "data": null, - "valid": false - } - ] - }, - { - "description": "boolean type matches booleans", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "boolean" - }, - "tests": [ - { - "description": "an integer is not a boolean", - "data": 1, - "valid": false - }, - { - "description": "zero is not a boolean", - "data": 0, - "valid": false - }, - { - "description": "a float is not a boolean", - "data": 1.1, - "valid": false - }, - { - "description": "a string is not a boolean", - "data": "foo", - "valid": false - }, - { - "description": "an empty string is not a boolean", - "data": "", - "valid": false - }, - { - "description": "an object is not a boolean", - "data": {}, - "valid": false - }, - { - "description": "an array is not a boolean", - "data": [], - "valid": false - }, - { - "description": "true is a boolean", - "data": true, - "valid": true - }, - { - "description": "false is a boolean", - "data": false, - "valid": true - }, - { - "description": "null is not a boolean", - "data": null, - "valid": false - } - ] - }, - { - "description": "null type matches only the null object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "null" - }, - "tests": [ - { - "description": "an integer is not null", - "data": 1, - "valid": false - }, - { - "description": "a float is not null", - "data": 1.1, - "valid": false - }, - { - "description": "zero is not null", - "data": 0, - "valid": false - }, - { - "description": "a string is not null", - "data": "foo", - "valid": false - }, - { - "description": "an empty string is not null", - "data": "", - "valid": false - }, - { - "description": "an object is not null", - "data": {}, - "valid": false - }, - { - "description": "an array is not null", - "data": [], - "valid": false - }, - { - "description": "true is not null", - "data": true, - "valid": false - }, - { - "description": "false is not null", - "data": false, - "valid": false - }, - { - "description": "null is null", - "data": null, - "valid": true - } - ] - }, - { - "description": "multiple types can be specified in an array", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["integer", "string"] - }, - "tests": [ - { - "description": "an integer is valid", - "data": 1, - "valid": true - }, - { - "description": "a string is valid", - "data": "foo", - "valid": true - }, - { - "description": "a float is invalid", - "data": 1.1, - "valid": false - }, - { - "description": "an object is invalid", - "data": {}, - "valid": false - }, - { - "description": "an array is invalid", - "data": [], - "valid": false - }, - { - "description": "a boolean is invalid", - "data": true, - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - } - ] - }, - { - "description": "type as array with one item", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["string"] - }, - "tests": [ - { - "description": "string is valid", - "data": "foo", - "valid": true - }, - { - "description": "number is invalid", - "data": 123, - "valid": false - } - ] - }, - { - "description": "type: array or object", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["array", "object"] - }, - "tests": [ - { - "description": "array is valid", - "data": [1,2,3], - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": 123}, - "valid": true - }, - { - "description": "number is invalid", - "data": 123, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - }, - { - "description": "null is invalid", - "data": null, - "valid": false - } - ] - }, - { - "description": "type: array, object or null", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": ["array", "object", "null"] - }, - "tests": [ - { - "description": "array is valid", - "data": [1,2,3], - "valid": true - }, - { - "description": "object is valid", - "data": {"foo": 123}, - "valid": true - }, - { - "description": "null is valid", - "data": null, - "valid": true - }, - { - "description": "number is invalid", - "data": 123, - "valid": false - }, - { - "description": "string is invalid", - "data": "foo", - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/unevaluatedItems.json b/jsonschema/testdata/draft2020-12/unevaluatedItems.json deleted file mode 100644 index f861cefa..00000000 --- a/jsonschema/testdata/draft2020-12/unevaluatedItems.json +++ /dev/null @@ -1,798 +0,0 @@ -[ - { - "description": "unevaluatedItems true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": true - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo"], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems as schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": { "type": "string" } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with valid unevaluated items", - "data": ["foo"], - "valid": true - }, - { - "description": "with invalid unevaluated items", - "data": [42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with uniform items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": { "type": "string" }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "unevaluatedItems doesn't apply", - "data": ["foo", "bar"], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with tuple", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "type": "string" } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with items and prefixItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "type": "string" } - ], - "items": true, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "unevaluatedItems doesn't apply", - "data": ["foo", 42], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "items": {"type": "number"}, - "unevaluatedItems": {"type": "string"} - }, - "tests": [ - { - "description": "valid under items", - "comment": "no elements are considered by unevaluatedItems", - "data": [5, 6, 7, 8], - "valid": true - }, - { - "description": "invalid under items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with nested tuple", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "type": "string" } - ], - "allOf": [ - { - "prefixItems": [ - true, - { "type": "number" } - ] - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", 42], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", 42, true], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with nested items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": {"type": "boolean"}, - "anyOf": [ - { "items": {"type": "string"} }, - true - ] - }, - "tests": [ - { - "description": "with only (valid) additional items", - "data": [true, false], - "valid": true - }, - { - "description": "with no additional items", - "data": ["yes", "no"], - "valid": true - }, - { - "description": "with invalid additional item", - "data": ["yes", false], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with nested prefixItems and items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "prefixItems": [ - { "type": "string" } - ], - "items": true - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no additional items", - "data": ["foo"], - "valid": true - }, - { - "description": "with additional items", - "data": ["foo", 42, true], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with nested unevaluatedItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "prefixItems": [ - { "type": "string" } - ] - }, - { "unevaluatedItems": true } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no additional items", - "data": ["foo"], - "valid": true - }, - { - "description": "with additional items", - "data": ["foo", 42, true], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with anyOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "anyOf": [ - { - "prefixItems": [ - true, - { "const": "bar" } - ] - }, - { - "prefixItems": [ - true, - true, - { "const": "baz" } - ] - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "when one schema matches and has no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "when one schema matches and has unevaluated items", - "data": ["foo", "bar", 42], - "valid": false - }, - { - "description": "when two schemas match and has no unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": true - }, - { - "description": "when two schemas match and has unevaluated items", - "data": ["foo", "bar", "baz", 42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "oneOf": [ - { - "prefixItems": [ - true, - { "const": "bar" } - ] - }, - { - "prefixItems": [ - true, - { "const": "baz" } - ] - } - ], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", 42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with not", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "not": { - "not": { - "prefixItems": [ - true, - { "const": "bar" } - ] - } - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with unevaluated items", - "data": ["foo", "bar"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with if/then/else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - { "const": "foo" } - ], - "if": { - "prefixItems": [ - true, - { "const": "bar" } - ] - }, - "then": { - "prefixItems": [ - true, - true, - { "const": "then" } - ] - }, - "else": { - "prefixItems": [ - true, - true, - true, - { "const": "else" } - ] - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "when if matches and it has no unevaluated items", - "data": ["foo", "bar", "then"], - "valid": true - }, - { - "description": "when if matches and it has unevaluated items", - "data": ["foo", "bar", "then", "else"], - "valid": false - }, - { - "description": "when if doesn't match and it has no unevaluated items", - "data": ["foo", 42, 42, "else"], - "valid": true - }, - { - "description": "when if doesn't match and it has unevaluated items", - "data": ["foo", 42, 42, "else", 42], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [true], - "unevaluatedItems": false - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": [], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$ref": "#/$defs/bar", - "prefixItems": [ - { "type": "string" } - ], - "unevaluatedItems": false, - "$defs": { - "bar": { - "prefixItems": [ - true, - { "type": "string" } - ] - } - } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems before $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": false, - "prefixItems": [ - { "type": "string" } - ], - "$ref": "#/$defs/bar", - "$defs": { - "bar": { - "prefixItems": [ - true, - { "type": "string" } - ] - } - } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems with $dynamicRef", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/unevaluated-items-with-dynamic-ref/derived", - - "$ref": "./baseSchema", - - "$defs": { - "derived": { - "$dynamicAnchor": "addons", - "prefixItems": [ - true, - { "type": "string" } - ] - }, - "baseSchema": { - "$id": "./baseSchema", - - "$comment": "unevaluatedItems comes first so it's more likely to catch bugs with implementations that are sensitive to keyword ordering", - "unevaluatedItems": false, - "type": "array", - "prefixItems": [ - { "type": "string" } - ], - "$dynamicRef": "#addons", - - "$defs": { - "defaultAddons": { - "$comment": "Needed to satisfy the bookending requirement", - "$dynamicAnchor": "addons" - } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated items", - "data": ["foo", "bar"], - "valid": true - }, - { - "description": "with unevaluated items", - "data": ["foo", "bar", "baz"], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems can't see inside cousins", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "prefixItems": [ true ] - }, - { "unevaluatedItems": false } - ] - }, - "tests": [ - { - "description": "always fails", - "data": [ 1 ], - "valid": false - } - ] - }, - { - "description": "item is evaluated in an uncle schema to unevaluatedItems", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": { - "foo": { - "prefixItems": [ - { "type": "string" } - ], - "unevaluatedItems": false - } - }, - "anyOf": [ - { - "properties": { - "foo": { - "prefixItems": [ - true, - { "type": "string" } - ] - } - } - } - ] - }, - "tests": [ - { - "description": "no extra items", - "data": { - "foo": [ - "test" - ] - }, - "valid": true - }, - { - "description": "uncle keyword evaluation is not significant", - "data": { - "foo": [ - "test", - "test" - ] - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedItems depends on adjacent contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [true], - "contains": {"type": "string"}, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "second item is evaluated by contains", - "data": [ 1, "foo" ], - "valid": true - }, - { - "description": "contains fails, second item is not evaluated", - "data": [ 1, 2 ], - "valid": false - }, - { - "description": "contains passes, second item is not evaluated", - "data": [ 1, 2, "foo" ], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems depends on multiple nested contains", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { "contains": { "multipleOf": 2 } }, - { "contains": { "multipleOf": 3 } } - ], - "unevaluatedItems": { "multipleOf": 5 } - }, - "tests": [ - { - "description": "5 not evaluated, passes unevaluatedItems", - "data": [ 2, 3, 4, 5, 6 ], - "valid": true - }, - { - "description": "7 not evaluated, fails unevaluatedItems", - "data": [ 2, 3, 4, 7, 8 ], - "valid": false - } - ] - }, - { - "description": "unevaluatedItems and contains interact to control item dependency relationship", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "contains": {"const": "a"} - }, - "then": { - "if": { - "contains": {"const": "b"} - }, - "then": { - "if": { - "contains": {"const": "c"} - } - } - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "empty array is valid", - "data": [], - "valid": true - }, - { - "description": "only a's are valid", - "data": [ "a", "a" ], - "valid": true - }, - { - "description": "a's and b's are valid", - "data": [ "a", "b", "a", "b", "a" ], - "valid": true - }, - { - "description": "a's, b's and c's are valid", - "data": [ "c", "a", "c", "c", "b", "a" ], - "valid": true - }, - { - "description": "only b's are invalid", - "data": [ "b", "b" ], - "valid": false - }, - { - "description": "only c's are invalid", - "data": [ "c", "c" ], - "valid": false - }, - { - "description": "only b's and c's are invalid", - "data": [ "c", "b", "c", "b", "c" ], - "valid": false - }, - { - "description": "only a's and c's are invalid", - "data": [ "c", "a", "c", "a", "c" ], - "valid": false - } - ] - }, - { - "description": "non-array instances are valid", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": false - }, - "tests": [ - { - "description": "ignores booleans", - "data": true, - "valid": true - }, - { - "description": "ignores integers", - "data": 123, - "valid": true - }, - { - "description": "ignores floats", - "data": 1.0, - "valid": true - }, - { - "description": "ignores objects", - "data": {}, - "valid": true - }, - { - "description": "ignores strings", - "data": "foo", - "valid": true - }, - { - "description": "ignores null", - "data": null, - "valid": true - } - ] - }, - { - "description": "unevaluatedItems with null instance elements", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedItems": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null elements", - "data": [ null ], - "valid": true - } - ] - }, - { - "description": "unevaluatedItems can see annotations from if without then and else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "prefixItems": [{"const": "a"}] - }, - "unevaluatedItems": false - }, - "tests": [ - { - "description": "valid in case if is evaluated", - "data": [ "a" ], - "valid": true - }, - { - "description": "invalid in case if is evaluated", - "data": [ "b" ], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/unevaluatedProperties.json b/jsonschema/testdata/draft2020-12/unevaluatedProperties.json deleted file mode 100644 index ae29c9eb..00000000 --- a/jsonschema/testdata/draft2020-12/unevaluatedProperties.json +++ /dev/null @@ -1,1601 +0,0 @@ -[ - { - "description": "unevaluatedProperties true", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": true - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": {}, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties schema", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": { - "type": "string", - "minLength": 3 - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": {}, - "valid": true - }, - { - "description": "with valid unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with invalid unevaluated properties", - "data": { - "foo": "fo" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": {}, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with adjacent properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with adjacent patternProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "patternProperties": { - "^foo": { "type": "string" } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with adjacent additionalProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "additionalProperties": true, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with nested properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "properties": { - "bar": { "type": "string" } - } - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with nested patternProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "patternProperties": { - "^bar": { "type": "string" } - } - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with nested additionalProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "additionalProperties": true - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no additional properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with additional properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with nested unevaluatedProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "unevaluatedProperties": true - } - ], - "unevaluatedProperties": { - "type": "string", - "maxLength": 2 - } - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with anyOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "anyOf": [ - { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - }, - { - "properties": { - "baz": { "const": "baz" } - }, - "required": ["baz"] - }, - { - "properties": { - "quux": { "const": "quux" } - }, - "required": ["quux"] - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when one matches and has no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "when one matches and has unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "not-baz" - }, - "valid": false - }, - { - "description": "when two match and has no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": true - }, - { - "description": "when two match and has unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz", - "quux": "not-quux" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "oneOf": [ - { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - }, - { - "properties": { - "baz": { "const": "baz" } - }, - "required": ["baz"] - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "quux": "quux" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with not", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "not": { - "not": { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with if/then/else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "if": { - "properties": { - "foo": { "const": "then" } - }, - "required": ["foo"] - }, - "then": { - "properties": { - "bar": { "type": "string" } - }, - "required": ["bar"] - }, - "else": { - "properties": { - "baz": { "type": "string" } - }, - "required": ["baz"] - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when if is true and has no unevaluated properties", - "data": { - "foo": "then", - "bar": "bar" - }, - "valid": true - }, - { - "description": "when if is true and has unevaluated properties", - "data": { - "foo": "then", - "bar": "bar", - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has no unevaluated properties", - "data": { - "baz": "baz" - }, - "valid": true - }, - { - "description": "when if is false and has unevaluated properties", - "data": { - "foo": "else", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with if/then/else, then not defined", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "if": { - "properties": { - "foo": { "const": "then" } - }, - "required": ["foo"] - }, - "else": { - "properties": { - "baz": { "type": "string" } - }, - "required": ["baz"] - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when if is true and has no unevaluated properties", - "data": { - "foo": "then", - "bar": "bar" - }, - "valid": false - }, - { - "description": "when if is true and has unevaluated properties", - "data": { - "foo": "then", - "bar": "bar", - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has no unevaluated properties", - "data": { - "baz": "baz" - }, - "valid": true - }, - { - "description": "when if is false and has unevaluated properties", - "data": { - "foo": "else", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with if/then/else, else not defined", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "if": { - "properties": { - "foo": { "const": "then" } - }, - "required": ["foo"] - }, - "then": { - "properties": { - "bar": { "type": "string" } - }, - "required": ["bar"] - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "when if is true and has no unevaluated properties", - "data": { - "foo": "then", - "bar": "bar" - }, - "valid": true - }, - { - "description": "when if is true and has unevaluated properties", - "data": { - "foo": "then", - "bar": "bar", - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has no unevaluated properties", - "data": { - "baz": "baz" - }, - "valid": false - }, - { - "description": "when if is false and has unevaluated properties", - "data": { - "foo": "else", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with dependentSchemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "dependentSchemas": { - "foo": { - "properties": { - "bar": { "const": "bar" } - }, - "required": ["bar"] - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with boolean schemas", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [true], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "$ref": "#/$defs/bar", - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false, - "$defs": { - "bar": { - "properties": { - "bar": { "type": "string" } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties before $ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "unevaluatedProperties": false, - "properties": { - "foo": { "type": "string" } - }, - "$ref": "#/$defs/bar", - "$defs": { - "bar": { - "properties": { - "bar": { "type": "string" } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties with $dynamicRef", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/unevaluated-properties-with-dynamic-ref/derived", - - "$ref": "./baseSchema", - - "$defs": { - "derived": { - "$dynamicAnchor": "addons", - "properties": { - "bar": { "type": "string" } - } - }, - "baseSchema": { - "$id": "./baseSchema", - - "$comment": "unevaluatedProperties comes first so it's more likely to catch bugs with implementations that are sensitive to keyword ordering", - "unevaluatedProperties": false, - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "$dynamicRef": "#addons", - - "$defs": { - "defaultAddons": { - "$comment": "Needed to satisfy the bookending requirement", - "$dynamicAnchor": "addons" - } - } - } - } - }, - "tests": [ - { - "description": "with no unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - }, - { - "description": "with unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar", - "baz": "baz" - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties can't see inside cousins", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "properties": { - "foo": true - } - }, - { - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "always fails", - "data": { - "foo": 1 - }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties can't see inside cousins (reverse order)", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "allOf": [ - { - "unevaluatedProperties": false - }, - { - "properties": { - "foo": true - } - } - ] - }, - "tests": [ - { - "description": "always fails", - "data": { - "foo": 1 - }, - "valid": false - } - ] - }, - { - "description": "nested unevaluatedProperties, outer false, inner true, properties outside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "unevaluatedProperties": true - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "nested unevaluatedProperties, outer false, inner true, properties inside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": true - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": true - } - ] - }, - { - "description": "nested unevaluatedProperties, outer true, inner false, properties outside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { "type": "string" } - }, - "allOf": [ - { - "unevaluatedProperties": false - } - ], - "unevaluatedProperties": true - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": false - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "nested unevaluatedProperties, outer true, inner false, properties inside", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false - } - ], - "unevaluatedProperties": true - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "cousin unevaluatedProperties, true and false, true with properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": true - }, - { - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": false - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "cousin unevaluatedProperties, true and false, false with properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "unevaluatedProperties": true - }, - { - "properties": { - "foo": { "type": "string" } - }, - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "with no nested unevaluated properties", - "data": { - "foo": "foo" - }, - "valid": true - }, - { - "description": "with nested unevaluated properties", - "data": { - "foo": "foo", - "bar": "bar" - }, - "valid": false - } - ] - }, - { - "description": "property is evaluated in an uncle schema to unevaluatedProperties", - "comment": "see https://stackoverflow.com/questions/66936884/deeply-nested-unevaluatedproperties-and-their-expectations", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": { - "type": "object", - "properties": { - "bar": { - "type": "string" - } - }, - "unevaluatedProperties": false - } - }, - "anyOf": [ - { - "properties": { - "foo": { - "properties": { - "faz": { - "type": "string" - } - } - } - } - } - ] - }, - "tests": [ - { - "description": "no extra properties", - "data": { - "foo": { - "bar": "test" - } - }, - "valid": true - }, - { - "description": "uncle keyword evaluation is not significant", - "data": { - "foo": { - "bar": "test", - "faz": "test" - } - }, - "valid": false - } - ] - }, - { - "description": "in-place applicator siblings, allOf has unevaluated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": true - }, - "unevaluatedProperties": false - } - ], - "anyOf": [ - { - "properties": { - "bar": true - } - } - ] - }, - "tests": [ - { - "description": "base case: both properties present", - "data": { - "foo": 1, - "bar": 1 - }, - "valid": false - }, - { - "description": "in place applicator siblings, bar is missing", - "data": { - "foo": 1 - }, - "valid": true - }, - { - "description": "in place applicator siblings, foo is missing", - "data": { - "bar": 1 - }, - "valid": false - } - ] - }, - { - "description": "in-place applicator siblings, anyOf has unevaluated", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "allOf": [ - { - "properties": { - "foo": true - } - } - ], - "anyOf": [ - { - "properties": { - "bar": true - }, - "unevaluatedProperties": false - } - ] - }, - "tests": [ - { - "description": "base case: both properties present", - "data": { - "foo": 1, - "bar": 1 - }, - "valid": false - }, - { - "description": "in place applicator siblings, bar is missing", - "data": { - "foo": 1 - }, - "valid": false - }, - { - "description": "in place applicator siblings, foo is missing", - "data": { - "bar": 1 - }, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties + single cyclic ref", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "x": { "$ref": "#" } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "Empty is valid", - "data": {}, - "valid": true - }, - { - "description": "Single is valid", - "data": { "x": {} }, - "valid": true - }, - { - "description": "Unevaluated on 1st level is invalid", - "data": { "x": {}, "y": {} }, - "valid": false - }, - { - "description": "Nested is valid", - "data": { "x": { "x": {} } }, - "valid": true - }, - { - "description": "Unevaluated on 2nd level is invalid", - "data": { "x": { "x": {}, "y": {} } }, - "valid": false - }, - { - "description": "Deep nested is valid", - "data": { "x": { "x": { "x": {} } } }, - "valid": true - }, - { - "description": "Unevaluated on 3rd level is invalid", - "data": { "x": { "x": { "x": {}, "y": {} } } }, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties + ref inside allOf / oneOf", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "one": { - "properties": { "a": true } - }, - "two": { - "required": ["x"], - "properties": { "x": true } - } - }, - "allOf": [ - { "$ref": "#/$defs/one" }, - { "properties": { "b": true } }, - { - "oneOf": [ - { "$ref": "#/$defs/two" }, - { - "required": ["y"], - "properties": { "y": true } - } - ] - } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "Empty is invalid (no x or y)", - "data": {}, - "valid": false - }, - { - "description": "a and b are invalid (no x or y)", - "data": { "a": 1, "b": 1 }, - "valid": false - }, - { - "description": "x and y are invalid", - "data": { "x": 1, "y": 1 }, - "valid": false - }, - { - "description": "a and x are valid", - "data": { "a": 1, "x": 1 }, - "valid": true - }, - { - "description": "a and y are valid", - "data": { "a": 1, "y": 1 }, - "valid": true - }, - { - "description": "a and b and x are valid", - "data": { "a": 1, "b": 1, "x": 1 }, - "valid": true - }, - { - "description": "a and b and y are valid", - "data": { "a": 1, "b": 1, "y": 1 }, - "valid": true - }, - { - "description": "a and b and x and y are invalid", - "data": { "a": 1, "b": 1, "x": 1, "y": 1 }, - "valid": false - } - ] - }, - { - "description": "dynamic evalation inside nested refs", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "one": { - "oneOf": [ - { "$ref": "#/$defs/two" }, - { "required": ["b"], "properties": { "b": true } }, - { "required": ["xx"], "patternProperties": { "x": true } }, - { "required": ["all"], "unevaluatedProperties": true } - ] - }, - "two": { - "oneOf": [ - { "required": ["c"], "properties": { "c": true } }, - { "required": ["d"], "properties": { "d": true } } - ] - } - }, - "oneOf": [ - { "$ref": "#/$defs/one" }, - { "required": ["a"], "properties": { "a": true } } - ], - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "Empty is invalid", - "data": {}, - "valid": false - }, - { - "description": "a is valid", - "data": { "a": 1 }, - "valid": true - }, - { - "description": "b is valid", - "data": { "b": 1 }, - "valid": true - }, - { - "description": "c is valid", - "data": { "c": 1 }, - "valid": true - }, - { - "description": "d is valid", - "data": { "d": 1 }, - "valid": true - }, - { - "description": "a + b is invalid", - "data": { "a": 1, "b": 1 }, - "valid": false - }, - { - "description": "a + c is invalid", - "data": { "a": 1, "c": 1 }, - "valid": false - }, - { - "description": "a + d is invalid", - "data": { "a": 1, "d": 1 }, - "valid": false - }, - { - "description": "b + c is invalid", - "data": { "b": 1, "c": 1 }, - "valid": false - }, - { - "description": "b + d is invalid", - "data": { "b": 1, "d": 1 }, - "valid": false - }, - { - "description": "c + d is invalid", - "data": { "c": 1, "d": 1 }, - "valid": false - }, - { - "description": "xx is valid", - "data": { "xx": 1 }, - "valid": true - }, - { - "description": "xx + foox is valid", - "data": { "xx": 1, "foox": 1 }, - "valid": true - }, - { - "description": "xx + foo is invalid", - "data": { "xx": 1, "foo": 1 }, - "valid": false - }, - { - "description": "xx + a is invalid", - "data": { "xx": 1, "a": 1 }, - "valid": false - }, - { - "description": "xx + b is invalid", - "data": { "xx": 1, "b": 1 }, - "valid": false - }, - { - "description": "xx + c is invalid", - "data": { "xx": 1, "c": 1 }, - "valid": false - }, - { - "description": "xx + d is invalid", - "data": { "xx": 1, "d": 1 }, - "valid": false - }, - { - "description": "all is valid", - "data": { "all": 1 }, - "valid": true - }, - { - "description": "all + foo is valid", - "data": { "all": 1, "foo": 1 }, - "valid": true - }, - { - "description": "all + a is invalid", - "data": { "all": 1, "a": 1 }, - "valid": false - } - ] - }, - { - "description": "non-object instances are valid", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "ignores booleans", - "data": true, - "valid": true - }, - { - "description": "ignores integers", - "data": 123, - "valid": true - }, - { - "description": "ignores floats", - "data": 1.0, - "valid": true - }, - { - "description": "ignores arrays", - "data": [], - "valid": true - }, - { - "description": "ignores strings", - "data": "foo", - "valid": true - }, - { - "description": "ignores null", - "data": null, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties with null valued instance properties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "unevaluatedProperties": { - "type": "null" - } - }, - "tests": [ - { - "description": "allows null valued properties", - "data": {"foo": null}, - "valid": true - } - ] - }, - { - "description": "unevaluatedProperties not affected by propertyNames", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "propertyNames": {"maxLength": 1}, - "unevaluatedProperties": { - "type": "number" - } - }, - "tests": [ - { - "description": "allows only number properties", - "data": {"a": 1}, - "valid": true - }, - { - "description": "string property is invalid", - "data": {"a": "b"}, - "valid": false - } - ] - }, - { - "description": "unevaluatedProperties can see annotations from if without then and else", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "if": { - "patternProperties": { - "foo": { - "type": "string" - } - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "valid in case if is evaluated", - "data": { - "foo": "a" - }, - "valid": true - }, - { - "description": "invalid in case if is evaluated", - "data": { - "bar": "a" - }, - "valid": false - } - ] - }, - { - "description": "dependentSchemas with unevaluatedProperties", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "properties": {"foo2": {}}, - "dependentSchemas": { - "foo" : {}, - "foo2": { - "properties": { - "bar":{} - } - } - }, - "unevaluatedProperties": false - }, - "tests": [ - { - "description": "unevaluatedProperties doesn't consider dependentSchemas", - "data": {"foo": ""}, - "valid": false - }, - { - "description": "unevaluatedProperties doesn't see bar when foo2 is absent", - "data": {"bar": ""}, - "valid": false - }, - { - "description": "unevaluatedProperties sees bar when foo2 is present", - "data": { "foo2": "", "bar": ""}, - "valid": true - } - ] - } -] diff --git a/jsonschema/testdata/draft2020-12/uniqueItems.json b/jsonschema/testdata/draft2020-12/uniqueItems.json deleted file mode 100644 index 4ea3bf98..00000000 --- a/jsonschema/testdata/draft2020-12/uniqueItems.json +++ /dev/null @@ -1,419 +0,0 @@ -[ - { - "description": "uniqueItems validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "uniqueItems": true - }, - "tests": [ - { - "description": "unique array of integers is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "non-unique array of integers is invalid", - "data": [1, 1], - "valid": false - }, - { - "description": "non-unique array of more than two integers is invalid", - "data": [1, 2, 1], - "valid": false - }, - { - "description": "numbers are unique if mathematically unequal", - "data": [1.0, 1.00, 1], - "valid": false - }, - { - "description": "false is not equal to zero", - "data": [0, false], - "valid": true - }, - { - "description": "true is not equal to one", - "data": [1, true], - "valid": true - }, - { - "description": "unique array of strings is valid", - "data": ["foo", "bar", "baz"], - "valid": true - }, - { - "description": "non-unique array of strings is invalid", - "data": ["foo", "bar", "foo"], - "valid": false - }, - { - "description": "unique array of objects is valid", - "data": [{"foo": "bar"}, {"foo": "baz"}], - "valid": true - }, - { - "description": "non-unique array of objects is invalid", - "data": [{"foo": "bar"}, {"foo": "bar"}], - "valid": false - }, - { - "description": "property order of array of objects is ignored", - "data": [{"foo": "bar", "bar": "foo"}, {"bar": "foo", "foo": "bar"}], - "valid": false - }, - { - "description": "unique array of nested objects is valid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : false}}} - ], - "valid": true - }, - { - "description": "non-unique array of nested objects is invalid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : true}}} - ], - "valid": false - }, - { - "description": "unique array of arrays is valid", - "data": [["foo"], ["bar"]], - "valid": true - }, - { - "description": "non-unique array of arrays is invalid", - "data": [["foo"], ["foo"]], - "valid": false - }, - { - "description": "non-unique array of more than two arrays is invalid", - "data": [["foo"], ["bar"], ["foo"]], - "valid": false - }, - { - "description": "1 and true are unique", - "data": [1, true], - "valid": true - }, - { - "description": "0 and false are unique", - "data": [0, false], - "valid": true - }, - { - "description": "[1] and [true] are unique", - "data": [[1], [true]], - "valid": true - }, - { - "description": "[0] and [false] are unique", - "data": [[0], [false]], - "valid": true - }, - { - "description": "nested [1] and [true] are unique", - "data": [[[1], "foo"], [[true], "foo"]], - "valid": true - }, - { - "description": "nested [0] and [false] are unique", - "data": [[[0], "foo"], [[false], "foo"]], - "valid": true - }, - { - "description": "unique heterogeneous types are valid", - "data": [{}, [1], true, null, 1, "{}"], - "valid": true - }, - { - "description": "non-unique heterogeneous types are invalid", - "data": [{}, [1], true, null, {}, 1], - "valid": false - }, - { - "description": "different objects are unique", - "data": [{"a": 1, "b": 2}, {"a": 2, "b": 1}], - "valid": true - }, - { - "description": "objects are non-unique despite key order", - "data": [{"a": 1, "b": 2}, {"b": 2, "a": 1}], - "valid": false - }, - { - "description": "{\"a\": false} and {\"a\": 0} are unique", - "data": [{"a": false}, {"a": 0}], - "valid": true - }, - { - "description": "{\"a\": true} and {\"a\": 1} are unique", - "data": [{"a": true}, {"a": 1}], - "valid": true - } - ] - }, - { - "description": "uniqueItems with an array of items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": true - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is not valid", - "data": [false, false], - "valid": false - }, - { - "description": "[true, true] from items array is not valid", - "data": [true, true], - "valid": false - }, - { - "description": "unique array extended from [false, true] is valid", - "data": [false, true, "foo", "bar"], - "valid": true - }, - { - "description": "unique array extended from [true, false] is valid", - "data": [true, false, "foo", "bar"], - "valid": true - }, - { - "description": "non-unique array extended from [false, true] is not valid", - "data": [false, true, "foo", "foo"], - "valid": false - }, - { - "description": "non-unique array extended from [true, false] is not valid", - "data": [true, false, "foo", "foo"], - "valid": false - } - ] - }, - { - "description": "uniqueItems with an array of items and additionalItems=false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": true, - "items": false - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is not valid", - "data": [false, false], - "valid": false - }, - { - "description": "[true, true] from items array is not valid", - "data": [true, true], - "valid": false - }, - { - "description": "extra items are invalid even if unique", - "data": [false, true, null], - "valid": false - } - ] - }, - { - "description": "uniqueItems=false validation", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "uniqueItems": false - }, - "tests": [ - { - "description": "unique array of integers is valid", - "data": [1, 2], - "valid": true - }, - { - "description": "non-unique array of integers is valid", - "data": [1, 1], - "valid": true - }, - { - "description": "numbers are unique if mathematically unequal", - "data": [1.0, 1.00, 1], - "valid": true - }, - { - "description": "false is not equal to zero", - "data": [0, false], - "valid": true - }, - { - "description": "true is not equal to one", - "data": [1, true], - "valid": true - }, - { - "description": "unique array of objects is valid", - "data": [{"foo": "bar"}, {"foo": "baz"}], - "valid": true - }, - { - "description": "non-unique array of objects is valid", - "data": [{"foo": "bar"}, {"foo": "bar"}], - "valid": true - }, - { - "description": "unique array of nested objects is valid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : false}}} - ], - "valid": true - }, - { - "description": "non-unique array of nested objects is valid", - "data": [ - {"foo": {"bar" : {"baz" : true}}}, - {"foo": {"bar" : {"baz" : true}}} - ], - "valid": true - }, - { - "description": "unique array of arrays is valid", - "data": [["foo"], ["bar"]], - "valid": true - }, - { - "description": "non-unique array of arrays is valid", - "data": [["foo"], ["foo"]], - "valid": true - }, - { - "description": "1 and true are unique", - "data": [1, true], - "valid": true - }, - { - "description": "0 and false are unique", - "data": [0, false], - "valid": true - }, - { - "description": "unique heterogeneous types are valid", - "data": [{}, [1], true, null, 1], - "valid": true - }, - { - "description": "non-unique heterogeneous types are valid", - "data": [{}, [1], true, null, {}, 1], - "valid": true - } - ] - }, - { - "description": "uniqueItems=false with an array of items", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": false - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is valid", - "data": [false, false], - "valid": true - }, - { - "description": "[true, true] from items array is valid", - "data": [true, true], - "valid": true - }, - { - "description": "unique array extended from [false, true] is valid", - "data": [false, true, "foo", "bar"], - "valid": true - }, - { - "description": "unique array extended from [true, false] is valid", - "data": [true, false, "foo", "bar"], - "valid": true - }, - { - "description": "non-unique array extended from [false, true] is valid", - "data": [false, true, "foo", "foo"], - "valid": true - }, - { - "description": "non-unique array extended from [true, false] is valid", - "data": [true, false, "foo", "foo"], - "valid": true - } - ] - }, - { - "description": "uniqueItems=false with an array of items and additionalItems=false", - "schema": { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [{"type": "boolean"}, {"type": "boolean"}], - "uniqueItems": false, - "items": false - }, - "tests": [ - { - "description": "[false, true] from items array is valid", - "data": [false, true], - "valid": true - }, - { - "description": "[true, false] from items array is valid", - "data": [true, false], - "valid": true - }, - { - "description": "[false, false] from items array is valid", - "data": [false, false], - "valid": true - }, - { - "description": "[true, true] from items array is valid", - "data": [true, true], - "valid": true - }, - { - "description": "extra items are invalid even if unique", - "data": [false, true, null], - "valid": false - } - ] - } -] diff --git a/jsonschema/testdata/remotes/README.md b/jsonschema/testdata/remotes/README.md deleted file mode 100644 index 8a641dbd..00000000 --- a/jsonschema/testdata/remotes/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# JSON Schema test suite: remote references - -These files were copied from -https://github.com/json-schema-org/JSON-Schema-Test-Suite/tree/83e866b46c9f9e7082fd51e83a61c5f2145a1ab7/remotes. diff --git a/jsonschema/testdata/remotes/different-id-ref-string.json b/jsonschema/testdata/remotes/different-id-ref-string.json deleted file mode 100644 index 7f888609..00000000 --- a/jsonschema/testdata/remotes/different-id-ref-string.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "$id": "http://localhost:1234/real-id-ref-string.json", - "$defs": {"bar": {"type": "string"}}, - "$ref": "#/$defs/bar" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json b/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/baseUriChange/folderInteger.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json b/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolder/folderInteger.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json b/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/baseUriChangeFolderInSubschema/folderInteger.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json b/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json deleted file mode 100644 index 07cce1da..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/detached-dynamicref.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/detached-dynamicref.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "foo": { - "$dynamicRef": "#detached" - }, - "detached": { - "$dynamicAnchor": "detached", - "type": "integer" - } - } -} \ No newline at end of file diff --git a/jsonschema/testdata/remotes/draft2020-12/detached-ref.json b/jsonschema/testdata/remotes/draft2020-12/detached-ref.json deleted file mode 100644 index 9c2dca93..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/detached-ref.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/detached-ref.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "foo": { - "$ref": "#detached" - }, - "detached": { - "$anchor": "detached", - "type": "integer" - } - } -} \ No newline at end of file diff --git a/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json b/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json deleted file mode 100644 index 65bc0c21..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/extendible-dynamic-ref.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "description": "extendible array", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/extendible-dynamic-ref.json", - "type": "object", - "properties": { - "elements": { - "type": "array", - "items": { - "$dynamicRef": "#elements" - } - } - }, - "required": ["elements"], - "additionalProperties": false, - "$defs": { - "elements": { - "$dynamicAnchor": "elements" - } - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json b/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json deleted file mode 100644 index 43a711c9..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/format-assertion-false.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/format-assertion-false.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/core": true, - "https://json-schema.org/draft/2020-12/vocab/format-assertion": false - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/format-assertion" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json b/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json deleted file mode 100644 index 39c6b0ab..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/format-assertion-true.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/format-assertion-true.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/core": true, - "https://json-schema.org/draft/2020-12/vocab/format-assertion": true - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/format-assertion" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/integer.json b/jsonschema/testdata/remotes/draft2020-12/integer.json deleted file mode 100644 index 1f44a631..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/integer.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "integer" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json b/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json deleted file mode 100644 index 6565a1ee..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/locationIndependentIdentifier.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "refToInteger": { - "$ref": "#foo" - }, - "A": { - "$anchor": "foo", - "type": "integer" - } - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json b/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json deleted file mode 100644 index 71be8b5d..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/metaschema-no-validation.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/metaschema-no-validation.json", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/applicator": true, - "https://json-schema.org/draft/2020-12/vocab/core": true - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/applicator" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json b/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json deleted file mode 100644 index a6963e54..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/metaschema-optional-vocabulary.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/metaschema-optional-vocabulary.json", - "$vocabulary": { - "https://json-schema.org/draft/2020-12/vocab/validation": true, - "https://json-schema.org/draft/2020-12/vocab/core": true, - "http://localhost:1234/draft/2020-12/vocab/custom": false - }, - "$dynamicAnchor": "meta", - "allOf": [ - { "$ref": "https://json-schema.org/draft/2020-12/meta/validation" }, - { "$ref": "https://json-schema.org/draft/2020-12/meta/core" } - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/name-defs.json b/jsonschema/testdata/remotes/draft2020-12/name-defs.json deleted file mode 100644 index 67bc33c5..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/name-defs.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "orNull": { - "anyOf": [ - { - "type": "null" - }, - { - "$ref": "#" - } - ] - } - }, - "type": "string" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json b/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json deleted file mode 100644 index 29661ff9..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/nested/foo-ref-string.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "properties": { - "foo": {"$ref": "string.json"} - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/nested/string.json b/jsonschema/testdata/remotes/draft2020-12/nested/string.json deleted file mode 100644 index 6607ac53..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/nested/string.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "string" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/prefixItems.json b/jsonschema/testdata/remotes/draft2020-12/prefixItems.json deleted file mode 100644 index acd8293c..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/prefixItems.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "$id": "http://localhost:1234/draft2020-12/prefixItems.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "prefixItems": [ - {"type": "string"} - ] -} diff --git a/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json b/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json deleted file mode 100644 index 16d30fa3..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/ref-and-defs.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/ref-and-defs.json", - "$defs": { - "inner": { - "properties": { - "bar": { "type": "string" } - } - } - }, - "$ref": "#/$defs/inner" -} diff --git a/jsonschema/testdata/remotes/draft2020-12/subSchemas.json b/jsonschema/testdata/remotes/draft2020-12/subSchemas.json deleted file mode 100644 index 1bb4846d..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/subSchemas.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$defs": { - "integer": { - "type": "integer" - }, - "refToInteger": { - "$ref": "#/$defs/integer" - } - } -} diff --git a/jsonschema/testdata/remotes/draft2020-12/tree.json b/jsonschema/testdata/remotes/draft2020-12/tree.json deleted file mode 100644 index b07555fb..00000000 --- a/jsonschema/testdata/remotes/draft2020-12/tree.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "description": "tree schema, extensible", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "http://localhost:1234/draft2020-12/tree.json", - "$dynamicAnchor": "node", - - "type": "object", - "properties": { - "data": true, - "children": { - "type": "array", - "items": { - "$dynamicRef": "#node" - } - } - } -} diff --git a/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json b/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json deleted file mode 100644 index f46c7616..00000000 --- a/jsonschema/testdata/remotes/nested-absolute-ref-to-string.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "$defs": { - "bar": { - "$id": "http://localhost:1234/the-nested-id.json", - "type": "string" - } - }, - "$ref": "http://localhost:1234/the-nested-id.json" -} diff --git a/jsonschema/testdata/remotes/urn-ref-string.json b/jsonschema/testdata/remotes/urn-ref-string.json deleted file mode 100644 index aca2211b..00000000 --- a/jsonschema/testdata/remotes/urn-ref-string.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "$id": "urn:uuid:feebdaed-ffff-0000-ffff-0000deadbeef", - "$defs": {"bar": {"type": "string"}}, - "$ref": "#/$defs/bar" -} diff --git a/jsonschema/util.go b/jsonschema/util.go deleted file mode 100644 index 71c34439..00000000 --- a/jsonschema/util.go +++ /dev/null @@ -1,420 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "bytes" - "cmp" - "encoding/binary" - "encoding/json" - "fmt" - "hash/maphash" - "math" - "math/big" - "reflect" - "slices" - "sync" - - "github.com/modelcontextprotocol/go-sdk/internal/util" -) - -// Equal reports whether two Go values representing JSON values are equal according -// to the JSON Schema spec. -// The values must not contain cycles. -// See https://json-schema.org/draft/2020-12/json-schema-core#section-4.2.2. -// It behaves like reflect.DeepEqual, except that numbers are compared according -// to mathematical equality. -func Equal(x, y any) bool { - return equalValue(reflect.ValueOf(x), reflect.ValueOf(y)) -} - -func equalValue(x, y reflect.Value) bool { - // Copied from src/reflect/deepequal.go, omitting the visited check (because JSON - // values are trees). - if !x.IsValid() || !y.IsValid() { - return x.IsValid() == y.IsValid() - } - - // Treat numbers specially. - rx, ok1 := jsonNumber(x) - ry, ok2 := jsonNumber(y) - if ok1 && ok2 { - return rx.Cmp(ry) == 0 - } - if x.Kind() != y.Kind() { - return false - } - switch x.Kind() { - case reflect.Array: - if x.Len() != y.Len() { - return false - } - for i := range x.Len() { - if !equalValue(x.Index(i), y.Index(i)) { - return false - } - } - return true - case reflect.Slice: - if x.IsNil() != y.IsNil() { - return false - } - if x.Len() != y.Len() { - return false - } - if x.UnsafePointer() == y.UnsafePointer() { - return true - } - // Special case for []byte, which is common. - if x.Type().Elem().Kind() == reflect.Uint8 && x.Type() == y.Type() { - return bytes.Equal(x.Bytes(), y.Bytes()) - } - for i := range x.Len() { - if !equalValue(x.Index(i), y.Index(i)) { - return false - } - } - return true - case reflect.Interface: - if x.IsNil() || y.IsNil() { - return x.IsNil() == y.IsNil() - } - return equalValue(x.Elem(), y.Elem()) - case reflect.Pointer: - if x.UnsafePointer() == y.UnsafePointer() { - return true - } - return equalValue(x.Elem(), y.Elem()) - case reflect.Struct: - t := x.Type() - if t != y.Type() { - return false - } - for i := range t.NumField() { - sf := t.Field(i) - if !sf.IsExported() { - continue - } - if !equalValue(x.FieldByIndex(sf.Index), y.FieldByIndex(sf.Index)) { - return false - } - } - return true - case reflect.Map: - if x.IsNil() != y.IsNil() { - return false - } - if x.Len() != y.Len() { - return false - } - if x.UnsafePointer() == y.UnsafePointer() { - return true - } - iter := x.MapRange() - for iter.Next() { - vx := iter.Value() - vy := y.MapIndex(iter.Key()) - if !vy.IsValid() || !equalValue(vx, vy) { - return false - } - } - return true - case reflect.Func: - if x.Type() != y.Type() { - return false - } - if x.IsNil() && y.IsNil() { - return true - } - panic("cannot compare functions") - case reflect.String: - return x.String() == y.String() - case reflect.Bool: - return x.Bool() == y.Bool() - // Ints, uints and floats handled in jsonNumber, at top of function. - default: - panic(fmt.Sprintf("unsupported kind: %s", x.Kind())) - } -} - -// hashValue adds v to the data hashed by h. v must not have cycles. -// hashValue panics if the value contains functions or channels, or maps whose -// key type is not string. -// It ignores unexported fields of structs. -// Calls to hashValue with the equal values (in the sense -// of [Equal]) result in the same sequence of values written to the hash. -func hashValue(h *maphash.Hash, v reflect.Value) { - // TODO: replace writes of basic types with WriteComparable in 1.24. - - writeUint := func(u uint64) { - var buf [8]byte - binary.BigEndian.PutUint64(buf[:], u) - h.Write(buf[:]) - } - - var write func(reflect.Value) - write = func(v reflect.Value) { - if r, ok := jsonNumber(v); ok { - // We want 1.0 and 1 to hash the same. - // big.Rats are always normalized, so they will be. - // We could do this more efficiently by handling the int and float cases - // separately, but that's premature. - writeUint(uint64(r.Sign() + 1)) - h.Write(r.Num().Bytes()) - h.Write(r.Denom().Bytes()) - return - } - switch v.Kind() { - case reflect.Invalid: - h.WriteByte(0) - case reflect.String: - h.WriteString(v.String()) - case reflect.Bool: - if v.Bool() { - h.WriteByte(1) - } else { - h.WriteByte(0) - } - case reflect.Complex64, reflect.Complex128: - c := v.Complex() - writeUint(math.Float64bits(real(c))) - writeUint(math.Float64bits(imag(c))) - case reflect.Array, reflect.Slice: - // Although we could treat []byte more efficiently, - // JSON values are unlikely to contain them. - writeUint(uint64(v.Len())) - for i := range v.Len() { - write(v.Index(i)) - } - case reflect.Interface, reflect.Pointer: - write(v.Elem()) - case reflect.Struct: - t := v.Type() - for i := range t.NumField() { - if sf := t.Field(i); sf.IsExported() { - write(v.FieldByIndex(sf.Index)) - } - } - case reflect.Map: - if v.Type().Key().Kind() != reflect.String { - panic("map with non-string key") - } - // Sort the keys so the hash is deterministic. - keys := v.MapKeys() - // Write the length. That distinguishes between, say, two consecutive - // maps with disjoint keys from one map that has the items of both. - writeUint(uint64(len(keys))) - slices.SortFunc(keys, func(x, y reflect.Value) int { return cmp.Compare(x.String(), y.String()) }) - for _, k := range keys { - write(k) - write(v.MapIndex(k)) - } - // Ints, uints and floats handled in jsonNumber, at top of function. - default: - panic(fmt.Sprintf("unsupported kind: %s", v.Kind())) - } - } - - write(v) -} - -// jsonNumber converts a numeric value or a json.Number to a [big.Rat]. -// If v is not a number, it returns nil, false. -func jsonNumber(v reflect.Value) (*big.Rat, bool) { - r := new(big.Rat) - switch { - case !v.IsValid(): - return nil, false - case v.CanInt(): - r.SetInt64(v.Int()) - case v.CanUint(): - r.SetUint64(v.Uint()) - case v.CanFloat(): - r.SetFloat64(v.Float()) - default: - jn, ok := v.Interface().(json.Number) - if !ok { - return nil, false - } - if _, ok := r.SetString(jn.String()); !ok { - // This can fail in rare cases; for example, "1e9999999". - // That is a valid JSON number, since the spec puts no limit on the size - // of the exponent. - return nil, false - } - } - return r, true -} - -// jsonType returns a string describing the type of the JSON value, -// as described in the JSON Schema specification: -// https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1. -// It returns "", false if the value is not valid JSON. -func jsonType(v reflect.Value) (string, bool) { - if !v.IsValid() { - // Not v.IsNil(): a nil []any is still a JSON array. - return "null", true - } - if v.CanInt() || v.CanUint() { - return "integer", true - } - if v.CanFloat() { - if _, f := math.Modf(v.Float()); f == 0 { - return "integer", true - } - return "number", true - } - switch v.Kind() { - case reflect.Bool: - return "boolean", true - case reflect.String: - return "string", true - case reflect.Slice, reflect.Array: - return "array", true - case reflect.Map, reflect.Struct: - return "object", true - default: - return "", false - } -} - -func assert(cond bool, msg string) { - if !cond { - panic("assertion failed: " + msg) - } -} - -// marshalStructWithMap marshals its first argument to JSON, treating the field named -// mapField as an embedded map. The first argument must be a pointer to -// a struct. The underlying type of mapField must be a map[string]any, and it must have -// a "-" json tag, meaning it will not be marshaled. -// -// For example, given this struct: -// -// type S struct { -// A int -// Extra map[string] any `json:"-"` -// } -// -// and this value: -// -// s := S{A: 1, Extra: map[string]any{"B": 2}} -// -// the call marshalJSONWithMap(s, "Extra") would return -// -// {"A": 1, "B": 2} -// -// It is an error if the map contains the same key as another struct field's -// JSON name. -// -// marshalStructWithMap calls json.Marshal on a value of type T, so T must not -// have a MarshalJSON method that calls this function, on pain of infinite regress. -// -// Note that there is a similar function in mcp/util.go, but they are not the same. -// Here the function requires `-` json tag, does not clear the mapField map, -// and handles embedded struct due to the implementation of jsonNames in this package. -// -// TODO: avoid this restriction on T by forcing it to marshal in a default way. -// See https://go.dev/play/p/EgXKJHxEx_R. -func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { - // Marshal the struct and the map separately, and concatenate the bytes. - // This strategy is dramatically less complicated than - // constructing a synthetic struct or map with the combined keys. - if s == nil { - return []byte("null"), nil - } - s2 := *s - vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) - mapVal := vMapField.Interface().(map[string]any) - - // Check for duplicates. - names := jsonNames(reflect.TypeFor[T]()) - for key := range mapVal { - if names[key] { - return nil, fmt.Errorf("map key %q duplicates struct field", key) - } - } - - structBytes, err := json.Marshal(s2) - if err != nil { - return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) - } - if len(mapVal) == 0 { - return structBytes, nil - } - mapBytes, err := json.Marshal(mapVal) - if err != nil { - return nil, err - } - if len(structBytes) == 2 { // must be "{}" - return mapBytes, nil - } - // "{X}" + "{Y}" => "{X,Y}" - res := append(structBytes[:len(structBytes)-1], ',') - res = append(res, mapBytes[1:]...) - return res, nil -} - -// unmarshalStructWithMap is the inverse of marshalStructWithMap. -// T has the same restrictions as in that function. -// -// Note that there is a similar function in mcp/util.go, but they are not the same. -// Here jsonNames also returns fields from embedded structs, hence this function -// handles embedded structs as well. -func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { - // Unmarshal into the struct, ignoring unknown fields. - if err := json.Unmarshal(data, v); err != nil { - return err - } - // Unmarshal into the map. - m := map[string]any{} - if err := json.Unmarshal(data, &m); err != nil { - return err - } - // Delete from the map the fields of the struct. - for n := range jsonNames(reflect.TypeFor[T]()) { - delete(m, n) - } - if len(m) != 0 { - reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) - } - return nil -} - -var jsonNamesMap sync.Map // from reflect.Type to map[string]bool - -// jsonNames returns the set of JSON object keys that t will marshal into, -// including fields from embedded structs in t. -// t must be a struct type. -// -// Note that there is a similar function in mcp/util.go, but they are not the same -// Here the function recurses over embedded structs and includes fields from them. -func jsonNames(t reflect.Type) map[string]bool { - // Lock not necessary: at worst we'll duplicate work. - if val, ok := jsonNamesMap.Load(t); ok { - return val.(map[string]bool) - } - m := map[string]bool{} - for i := range t.NumField() { - field := t.Field(i) - // handle embedded structs - if field.Anonymous { - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - for n := range jsonNames(fieldType) { - m[n] = true - } - continue - } - info := util.FieldJSONInfo(field) - if !info.Omit { - m[info.Name] = true - } - } - jsonNamesMap.Store(t, m) - return m -} diff --git a/jsonschema/util_test.go b/jsonschema/util_test.go deleted file mode 100644 index 03ccb4d7..00000000 --- a/jsonschema/util_test.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "hash/maphash" - "reflect" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func TestEqual(t *testing.T) { - for _, tt := range []struct { - x1, x2 any - want bool - }{ - {0, 1, false}, - {1, 1.0, true}, - {nil, 0, false}, - {"0", 0, false}, - {2.5, 2.5, true}, - {[]int{1, 2}, []float64{1.0, 2.0}, true}, - {[]int(nil), []int{}, false}, - {[]map[string]any(nil), []map[string]any{}, false}, - { - map[string]any{"a": 1, "b": 2.0}, - map[string]any{"a": 1.0, "b": 2}, - true, - }, - } { - check := func(x1, x2 any, want bool) { - t.Helper() - if got := Equal(x1, x2); got != want { - t.Errorf("jsonEqual(%#v, %#v) = %t, want %t", x1, x2, got, want) - } - } - check(tt.x1, tt.x1, true) - check(tt.x2, tt.x2, true) - check(tt.x1, tt.x2, tt.want) - check(tt.x2, tt.x1, tt.want) - } -} - -func TestJSONType(t *testing.T) { - for _, tt := range []struct { - val string - want string - }{ - {`null`, "null"}, - {`0`, "integer"}, - {`0.0`, "integer"}, - {`1e2`, "integer"}, - {`0.1`, "number"}, - {`""`, "string"}, - {`true`, "boolean"}, - {`[]`, "array"}, - {`{}`, "object"}, - } { - var val any - if err := json.Unmarshal([]byte(tt.val), &val); err != nil { - t.Fatal(err) - } - got, ok := jsonType(reflect.ValueOf(val)) - if !ok { - t.Fatalf("jsonType failed on %q", tt.val) - } - if got != tt.want { - t.Errorf("%s: got %q, want %q", tt.val, got, tt.want) - } - - } -} - -func TestHash(t *testing.T) { - x := map[string]any{ - "s": []any{1, "foo", nil, true}, - "f": 2.5, - "m": map[string]any{ - "n": json.Number("123.456"), - "schema": &Schema{Type: "integer", UniqueItems: true}, - }, - "c": 1.2 + 3.4i, - "n": nil, - } - - seed := maphash.MakeSeed() - - hash := func(x any) uint64 { - var h maphash.Hash - h.SetSeed(seed) - hashValue(&h, reflect.ValueOf(x)) - return h.Sum64() - } - - want := hash(x) - // Run several times to verify consistency. - for range 10 { - if got := hash(x); got != want { - t.Errorf("hash values differ: %d vs. %d", got, want) - } - } - - // Check mathematically equal values. - nums := []any{ - 5, - uint(5), - 5.0, - json.Number("5"), - json.Number("5.00"), - } - for i, n := range nums { - if i == 0 { - want = hash(n) - } else if got := hash(n); got != want { - t.Errorf("hashes differ between %v (%[1]T) and %v (%[2]T)", nums[0], n) - } - } - - // Check that a bare JSON `null` is OK. - var null any - if err := json.Unmarshal([]byte(`null`), &null); err != nil { - t.Fatal(err) - } - _ = hash(null) -} - -func TestMarshalStructWithMap(t *testing.T) { - type S struct { - A int - B string `json:"b,omitempty"` - u bool - M map[string]any `json:"-"` - } - t.Run("basic", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"!@#": true}} - got, err := marshalStructWithMap(&s, "M") - if err != nil { - t.Fatal(err) - } - want := `{"A":1,"b":"two","!@#":true}` - if g := string(got); g != want { - t.Errorf("\ngot %s\nwant %s", g, want) - } - - var un S - if err := unmarshalStructWithMap(got, &un, "M"); err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(s, un, cmpopts.IgnoreUnexported(S{})); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } - }) - t.Run("duplicate", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"b": "dup"}} - _, err := marshalStructWithMap(&s, "M") - if err == nil || !strings.Contains(err.Error(), "duplicate") { - t.Errorf("got %v, want error with 'duplicate'", err) - } - }) - t.Run("embedded", func(t *testing.T) { - type Embedded struct { - A int - B int - Extra map[string]any `json:"-"` - } - type S struct { - C int - Embedded - } - s := S{C: 1, Embedded: Embedded{A: 2, B: 3, Extra: map[string]any{"d": 4, "e": 5}}} - got, err := marshalStructWithMap(&s, "Extra") - if err != nil { - t.Fatal(err) - } - want := `{"C":1,"A":2,"B":3,"d":4,"e":5}` - if g := string(got); g != want { - t.Errorf("got %v, want %v", g, want) - } - }) -} diff --git a/jsonschema/validate.go b/jsonschema/validate.go deleted file mode 100644 index a04e42bd..00000000 --- a/jsonschema/validate.go +++ /dev/null @@ -1,754 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "fmt" - "hash/maphash" - "iter" - "math" - "math/big" - "reflect" - "slices" - "strings" - "sync" - "unicode/utf8" - - "github.com/modelcontextprotocol/go-sdk/internal/util" -) - -// The value of the "$schema" keyword for the version that we can validate. -const draft202012 = "https://json-schema.org/draft/2020-12/schema" - -// Validate validates the instance, which must be a JSON value, against the schema. -// It returns nil if validation is successful or an error if it is not. -// If the schema type is "object", instance can be a map[string]any or a struct. -func (rs *Resolved) Validate(instance any) error { - if s := rs.root.Schema; s != "" && s != draft202012 { - return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) - } - st := &state{rs: rs} - return st.validate(reflect.ValueOf(instance), st.rs.root, nil) -} - -// validateDefaults walks the schema tree. If it finds a default, it validates it -// against the schema containing it. -// -// TODO(jba): account for dynamic refs. This algorithm simple-mindedly -// treats each schema with a default as its own root. -func (rs *Resolved) validateDefaults() error { - if s := rs.root.Schema; s != "" && s != draft202012 { - return fmt.Errorf("cannot validate version %s, only %s", s, draft202012) - } - st := &state{rs: rs} - for s := range rs.root.all() { - // We checked for nil schemas in [Schema.Resolve]. - assert(s != nil, "nil schema") - if s.DynamicRef != "" { - return fmt.Errorf("jsonschema: %s: validateDefaults does not support dynamic refs", s) - } - if s.Default != nil { - var d any - if err := json.Unmarshal(s.Default, &d); err != nil { - return fmt.Errorf("unmarshaling default value of schema %s: %w", s, err) - } - if err := st.validate(reflect.ValueOf(d), s, nil); err != nil { - return err - } - } - } - return nil -} - -// state is the state of single call to ResolvedSchema.Validate. -type state struct { - rs *Resolved - // stack holds the schemas from recursive calls to validate. - // These are the "dynamic scopes" used to resolve dynamic references. - // https://json-schema.org/draft/2020-12/json-schema-core#scopes - stack []*Schema -} - -// validate validates the reflected value of the instance. -func (st *state) validate(instance reflect.Value, schema *Schema, callerAnns *annotations) (err error) { - defer util.Wrapf(&err, "validating %s", schema) - - // Maintain a stack for dynamic schema resolution. - st.stack = append(st.stack, schema) // push - defer func() { - st.stack = st.stack[:len(st.stack)-1] // pop - }() - - // We checked for nil schemas in [Schema.Resolve]. - assert(schema != nil, "nil schema") - - // Step through interfaces and pointers. - for instance.Kind() == reflect.Pointer || instance.Kind() == reflect.Interface { - instance = instance.Elem() - } - - // type: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.1 - if schema.Type != "" || schema.Types != nil { - gotType, ok := jsonType(instance) - if !ok { - return fmt.Errorf("type: %v of type %[1]T is not a valid JSON value", instance) - } - if schema.Type != "" { - // "number" subsumes integers - if !(gotType == schema.Type || - gotType == "integer" && schema.Type == "number") { - return fmt.Errorf("type: %v has type %q, want %q", instance, gotType, schema.Type) - } - } else { - if !(slices.Contains(schema.Types, gotType) || (gotType == "integer" && slices.Contains(schema.Types, "number"))) { - return fmt.Errorf("type: %v has type %q, want one of %q", - instance, gotType, strings.Join(schema.Types, ", ")) - } - } - } - // enum: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.2 - if schema.Enum != nil { - ok := false - for _, e := range schema.Enum { - if equalValue(reflect.ValueOf(e), instance) { - ok = true - break - } - } - if !ok { - return fmt.Errorf("enum: %v does not equal any of: %v", instance, schema.Enum) - } - } - - // const: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.1.3 - if schema.Const != nil { - if !equalValue(reflect.ValueOf(*schema.Const), instance) { - return fmt.Errorf("const: %v does not equal %v", instance, *schema.Const) - } - } - - // numbers: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.2 - if schema.MultipleOf != nil || schema.Minimum != nil || schema.Maximum != nil || schema.ExclusiveMinimum != nil || schema.ExclusiveMaximum != nil { - n, ok := jsonNumber(instance) - if ok { // these keywords don't apply to non-numbers - if schema.MultipleOf != nil { - // TODO: validate MultipleOf as non-zero. - // The test suite assumes floats. - nf, _ := n.Float64() // don't care if it's exact or not - if _, f := math.Modf(nf / *schema.MultipleOf); f != 0 { - return fmt.Errorf("multipleOf: %s is not a multiple of %f", n, *schema.MultipleOf) - } - } - - m := new(big.Rat) // reuse for all of the following - cmp := func(f float64) int { return n.Cmp(m.SetFloat64(f)) } - - if schema.Minimum != nil && cmp(*schema.Minimum) < 0 { - return fmt.Errorf("minimum: %s is less than %f", n, *schema.Minimum) - } - if schema.Maximum != nil && cmp(*schema.Maximum) > 0 { - return fmt.Errorf("maximum: %s is greater than %f", n, *schema.Maximum) - } - if schema.ExclusiveMinimum != nil && cmp(*schema.ExclusiveMinimum) <= 0 { - return fmt.Errorf("exclusiveMinimum: %s is less than or equal to %f", n, *schema.ExclusiveMinimum) - } - if schema.ExclusiveMaximum != nil && cmp(*schema.ExclusiveMaximum) >= 0 { - return fmt.Errorf("exclusiveMaximum: %s is greater than or equal to %f", n, *schema.ExclusiveMaximum) - } - } - } - - // strings: https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.3 - if instance.Kind() == reflect.String && (schema.MinLength != nil || schema.MaxLength != nil || schema.Pattern != "") { - str := instance.String() - n := utf8.RuneCountInString(str) - if schema.MinLength != nil { - if m := *schema.MinLength; n < m { - return fmt.Errorf("minLength: %q contains %d Unicode code points, fewer than %d", str, n, m) - } - } - if schema.MaxLength != nil { - if m := *schema.MaxLength; n > m { - return fmt.Errorf("maxLength: %q contains %d Unicode code points, more than %d", str, n, m) - } - } - - if schema.Pattern != "" && !schema.pattern.MatchString(str) { - return fmt.Errorf("pattern: %q does not match regular expression %q", str, schema.Pattern) - } - } - - var anns annotations // all the annotations for this call and child calls - - // $ref: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.1 - if schema.Ref != "" { - if err := st.validate(instance, schema.resolvedRef, &anns); err != nil { - return err - } - } - - // $dynamicRef: https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2 - if schema.DynamicRef != "" { - // The ref behaves lexically or dynamically, but not both. - assert((schema.resolvedDynamicRef == nil) != (schema.dynamicRefAnchor == ""), - "DynamicRef not resolved properly") - if schema.resolvedDynamicRef != nil { - // Same as $ref. - if err := st.validate(instance, schema.resolvedDynamicRef, &anns); err != nil { - return err - } - } else { - // Dynamic behavior. - // Look for the base of the outermost schema on the stack with this dynamic - // anchor. (Yes, outermost: the one farthest from here. This the opposite - // of how ordinary dynamic variables behave.) - // Why the base of the schema being validated and not the schema itself? - // Because the base is the scope for anchors. In fact it's possible to - // refer to a schema that is not on the stack, but a child of some base - // on the stack. - // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. - var dynamicSchema *Schema - for _, s := range st.stack { - info, ok := s.base.anchors[schema.dynamicRefAnchor] - if ok && info.dynamic { - dynamicSchema = info.schema - break - } - } - if dynamicSchema == nil { - return fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) - } - if err := st.validate(instance, dynamicSchema, &anns); err != nil { - return err - } - } - } - - // logic - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2 - // These must happen before arrays and objects because if they evaluate an item or property, - // then the unevaluatedItems/Properties schemas don't apply to it. - // See https://json-schema.org/draft/2020-12/json-schema-core#section-11.2, paragraph 4. - // - // If any of these fail, then validation fails, even if there is an unevaluatedXXX - // keyword in the schema. The spec is unclear about this, but that is the intention. - - valid := func(s *Schema, anns *annotations) bool { return st.validate(instance, s, anns) == nil } - - if schema.AllOf != nil { - for _, ss := range schema.AllOf { - if err := st.validate(instance, ss, &anns); err != nil { - return err - } - } - } - if schema.AnyOf != nil { - // We must visit them all, to collect annotations. - ok := false - for _, ss := range schema.AnyOf { - if valid(ss, &anns) { - ok = true - } - } - if !ok { - return fmt.Errorf("anyOf: did not validate against any of %v", schema.AnyOf) - } - } - if schema.OneOf != nil { - // Exactly one. - var okSchema *Schema - for _, ss := range schema.OneOf { - if valid(ss, &anns) { - if okSchema != nil { - return fmt.Errorf("oneOf: validated against both %v and %v", okSchema, ss) - } - okSchema = ss - } - } - if okSchema == nil { - return fmt.Errorf("oneOf: did not validate against any of %v", schema.OneOf) - } - } - if schema.Not != nil { - // Ignore annotations from "not". - if valid(schema.Not, nil) { - return fmt.Errorf("not: validated against %v", schema.Not) - } - } - if schema.If != nil { - var ss *Schema - if valid(schema.If, &anns) { - ss = schema.Then - } else { - ss = schema.Else - } - if ss != nil { - if err := st.validate(instance, ss, &anns); err != nil { - return err - } - } - } - - // arrays - // TODO(jba): consider arrays of structs. - if instance.Kind() == reflect.Array || instance.Kind() == reflect.Slice { - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.1 - // This validate call doesn't collect annotations for the items of the instance; they are separate - // instances in their own right. - // TODO(jba): if the test suite doesn't cover this case, add a test. For example, nested arrays. - for i, ischema := range schema.PrefixItems { - if i >= instance.Len() { - break // shorter is OK - } - if err := st.validate(instance.Index(i), ischema, nil); err != nil { - return err - } - } - anns.noteEndIndex(min(len(schema.PrefixItems), instance.Len())) - - if schema.Items != nil { - for i := len(schema.PrefixItems); i < instance.Len(); i++ { - if err := st.validate(instance.Index(i), schema.Items, nil); err != nil { - return err - } - } - // Note that all the items in this array have been validated. - anns.allItems = true - } - - nContains := 0 - if schema.Contains != nil { - for i := range instance.Len() { - if err := st.validate(instance.Index(i), schema.Contains, nil); err == nil { - nContains++ - anns.noteIndex(i) - } - } - if nContains == 0 && (schema.MinContains == nil || *schema.MinContains > 0) { - return fmt.Errorf("contains: %s does not have an item matching %s", instance, schema.Contains) - } - } - - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.4 - // TODO(jba): check that these next four keywords' values are integers. - if schema.MinContains != nil && schema.Contains != nil { - if m := *schema.MinContains; nContains < m { - return fmt.Errorf("minContains: contains validated %d items, less than %d", nContains, m) - } - } - if schema.MaxContains != nil && schema.Contains != nil { - if m := *schema.MaxContains; nContains > m { - return fmt.Errorf("maxContains: contains validated %d items, greater than %d", nContains, m) - } - } - if schema.MinItems != nil { - if m := *schema.MinItems; instance.Len() < m { - return fmt.Errorf("minItems: array length %d is less than %d", instance.Len(), m) - } - } - if schema.MaxItems != nil { - if m := *schema.MaxItems; instance.Len() > m { - return fmt.Errorf("maxItems: array length %d is greater than %d", instance.Len(), m) - } - } - if schema.UniqueItems { - if instance.Len() > 1 { - // Hash each item and compare the hashes. - // If two hashes differ, the items differ. - // If two hashes are the same, compare the collisions for equality. - // (The same logic as hash table lookup.) - // TODO(jba): Use container/hash.Map when it becomes available (https://go.dev/issue/69559), - hashes := map[uint64][]int{} // from hash to indices - seed := maphash.MakeSeed() - for i := range instance.Len() { - item := instance.Index(i) - var h maphash.Hash - h.SetSeed(seed) - hashValue(&h, item) - hv := h.Sum64() - if sames := hashes[hv]; len(sames) > 0 { - for _, j := range sames { - if equalValue(item, instance.Index(j)) { - return fmt.Errorf("uniqueItems: array items %d and %d are equal", i, j) - } - } - } - hashes[hv] = append(hashes[hv], i) - } - } - } - - // https://json-schema.org/draft/2020-12/json-schema-core#section-11.2 - if schema.UnevaluatedItems != nil && !anns.allItems { - // Apply this subschema to all items in the array that haven't been successfully validated. - // That includes validations by subschemas on the same instance, like allOf. - for i := anns.endIndex; i < instance.Len(); i++ { - if !anns.evaluatedIndexes[i] { - if err := st.validate(instance.Index(i), schema.UnevaluatedItems, nil); err != nil { - return err - } - } - } - anns.allItems = true - } - } - - // objects - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.3.2 - if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { - if instance.Kind() == reflect.Map { - if kt := instance.Type().Key(); kt.Kind() != reflect.String { - return fmt.Errorf("map key type %s is not a string", kt) - } - } - // Track the evaluated properties for just this schema, to support additionalProperties. - // If we used anns here, then we'd be including properties evaluated in subschemas - // from allOf, etc., which additionalProperties shouldn't observe. - evalProps := map[string]bool{} - for prop, subschema := range schema.Properties { - val := property(instance, prop) - if !val.IsValid() { - // It's OK if the instance doesn't have the property. - continue - } - // If the instance is a struct and an optional property has the zero - // value, then we could interpret it as present or missing. Be generous: - // assume it's missing, and thus always validates successfully. - if instance.Kind() == reflect.Struct && val.IsZero() && !schema.isRequired[prop] { - continue - } - if err := st.validate(val, subschema, nil); err != nil { - return err - } - evalProps[prop] = true - } - if len(schema.PatternProperties) > 0 { - for prop, val := range properties(instance) { - // Check every matching pattern. - for re, schema := range schema.patternProperties { - if re.MatchString(prop) { - if err := st.validate(val, schema, nil); err != nil { - return err - } - evalProps[prop] = true - } - } - } - } - if schema.AdditionalProperties != nil { - // Apply to all properties not handled above. - for prop, val := range properties(instance) { - if !evalProps[prop] { - if err := st.validate(val, schema.AdditionalProperties, nil); err != nil { - return err - } - evalProps[prop] = true - } - } - } - anns.noteProperties(evalProps) - if schema.PropertyNames != nil { - // Note: properties unnecessarily fetches each value. We could define a propertyNames function - // if performance ever matters. - for prop := range properties(instance) { - if err := st.validate(reflect.ValueOf(prop), schema.PropertyNames, nil); err != nil { - return err - } - } - } - - // https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-01#section-6.5 - var min, max int - if schema.MinProperties != nil || schema.MaxProperties != nil { - min, max = numPropertiesBounds(instance, schema.isRequired) - } - if schema.MinProperties != nil { - if n, m := max, *schema.MinProperties; n < m { - return fmt.Errorf("minProperties: object has %d properties, less than %d", n, m) - } - } - if schema.MaxProperties != nil { - if n, m := min, *schema.MaxProperties; n > m { - return fmt.Errorf("maxProperties: object has %d properties, greater than %d", n, m) - } - } - - hasProperty := func(prop string) bool { - return property(instance, prop).IsValid() - } - - missingProperties := func(props []string) []string { - var missing []string - for _, p := range props { - if !hasProperty(p) { - missing = append(missing, p) - } - } - return missing - } - - if schema.Required != nil { - if m := missingProperties(schema.Required); len(m) > 0 { - return fmt.Errorf("required: missing properties: %q", m) - } - } - if schema.DependentRequired != nil { - // "Validation succeeds if, for each name that appears in both the instance - // and as a name within this keyword's value, every item in the corresponding - // array is also the name of a property in the instance." §6.5.4 - for dprop, reqs := range schema.DependentRequired { - if hasProperty(dprop) { - if m := missingProperties(reqs); len(m) > 0 { - return fmt.Errorf("dependentRequired[%q]: missing properties %q", dprop, m) - } - } - } - } - - // https://json-schema.org/draft/2020-12/json-schema-core#section-10.2.2.4 - if schema.DependentSchemas != nil { - // This does not collect annotations, although it seems like it should. - for dprop, ss := range schema.DependentSchemas { - if hasProperty(dprop) { - // TODO: include dependentSchemas[dprop] in the errors. - err := st.validate(instance, ss, &anns) - if err != nil { - return err - } - } - } - } - if schema.UnevaluatedProperties != nil && !anns.allProperties { - // This looks a lot like AdditionalProperties, but depends on in-place keywords like allOf - // in addition to sibling keywords. - for prop, val := range properties(instance) { - if !anns.evaluatedProperties[prop] { - if err := st.validate(val, schema.UnevaluatedProperties, nil); err != nil { - return err - } - } - } - // The spec says the annotation should be the set of evaluated properties, but we can optimize - // by setting a single boolean, since after this succeeds all properties will be validated. - // See https://json-schema.slack.com/archives/CT7FF623C/p1745592564381459. - anns.allProperties = true - } - } - - if callerAnns != nil { - // Our caller wants to know what we've validated. - callerAnns.merge(&anns) - } - return nil -} - -// resolveDynamicRef returns the schema referred to by the argument schema's -// $dynamicRef value. -// It returns an error if the dynamic reference has no referent. -// If there is no $dynamicRef, resolveDynamicRef returns nil, nil. -// See https://json-schema.org/draft/2020-12/json-schema-core#section-8.2.3.2. -func (st *state) resolveDynamicRef(schema *Schema) (*Schema, error) { - if schema.DynamicRef == "" { - return nil, nil - } - // The ref behaves lexically or dynamically, but not both. - assert((schema.resolvedDynamicRef == nil) != (schema.dynamicRefAnchor == ""), - "DynamicRef not statically resolved properly") - if r := schema.resolvedDynamicRef; r != nil { - // Same as $ref. - return r, nil - } - // Dynamic behavior. - // Look for the base of the outermost schema on the stack with this dynamic - // anchor. (Yes, outermost: the one farthest from here. This the opposite - // of how ordinary dynamic variables behave.) - // Why the base of the schema being validated and not the schema itself? - // Because the base is the scope for anchors. In fact it's possible to - // refer to a schema that is not on the stack, but a child of some base - // on the stack. - // For an example, search for "detached" in testdata/draft2020-12/dynamicRef.json. - for _, s := range st.stack { - info, ok := s.base.anchors[schema.dynamicRefAnchor] - if ok && info.dynamic { - return info.schema, nil - } - } - return nil, fmt.Errorf("missing dynamic anchor %q", schema.dynamicRefAnchor) -} - -// ApplyDefaults modifies an instance by applying the schema's defaults to it. If -// a schema or sub-schema has a default, then a corresponding zero instance value -// is set to the default. -// -// The JSON Schema specification does not describe how defaults should be interpreted. -// This method honors defaults only on properties, and only those that are not required. -// If the instance is a map and the property is missing, the property is added to -// the map with the default. -// If the instance is a struct, the field corresponding to the property exists, and -// its value is zero, the field is set to the default. -// ApplyDefaults can panic if a default cannot be assigned to a field. -// -// The argument must be a pointer to the instance. -// (In case we decide that top-level defaults are meaningful.) -// -// It is recommended to first call Resolve with a ValidateDefaults option of true, -// then call this method, and lastly call Validate. -// -// TODO(jba): consider what defaults on top-level or array instances might mean. -// TODO(jba): follow $ref and $dynamicRef -// TODO(jba): apply defaults on sub-schemas to corresponding sub-instances. -func (rs *Resolved) ApplyDefaults(instancep any) error { - st := &state{rs: rs} - return st.applyDefaults(reflect.ValueOf(instancep), rs.root) -} - -// Leave this as a potentially recursive helper function, because we'll surely want -// to apply defaults on sub-schemas someday. -func (st *state) applyDefaults(instancep reflect.Value, schema *Schema) (err error) { - defer util.Wrapf(&err, "applyDefaults: schema %s, instance %v", schema, instancep) - - instance := instancep.Elem() - if instance.Kind() == reflect.Map || instance.Kind() == reflect.Struct { - if instance.Kind() == reflect.Map { - if kt := instance.Type().Key(); kt.Kind() != reflect.String { - return fmt.Errorf("map key type %s is not a string", kt) - } - } - for prop, subschema := range schema.Properties { - // Ignore defaults on required properties. (A required property shouldn't have a default.) - if schema.isRequired[prop] { - continue - } - val := property(instance, prop) - switch instance.Kind() { - case reflect.Map: - // If there is a default for this property, and the map key is missing, - // set the map value to the default. - if subschema.Default != nil && !val.IsValid() { - // Create an lvalue, since map values aren't addressable. - lvalue := reflect.New(instance.Type().Elem()) - if err := json.Unmarshal(subschema.Default, lvalue.Interface()); err != nil { - return err - } - instance.SetMapIndex(reflect.ValueOf(prop), lvalue.Elem()) - } - case reflect.Struct: - // If there is a default for this property, and the field exists but is zero, - // set the field to the default. - if subschema.Default != nil && val.IsValid() && val.IsZero() { - if err := json.Unmarshal(subschema.Default, val.Addr().Interface()); err != nil { - return err - } - } - default: - panic(fmt.Sprintf("applyDefaults: property %s: bad value %s of kind %s", - prop, instance, instance.Kind())) - } - } - } - return nil -} - -// property returns the value of the property of v with the given name, or the invalid -// reflect.Value if there is none. -// If v is a map, the property is the value of the map whose key is name. -// If v is a struct, the property is the value of the field with the given name according -// to the encoding/json package (see [jsonName]). -// If v is anything else, property panics. -func property(v reflect.Value, name string) reflect.Value { - switch v.Kind() { - case reflect.Map: - return v.MapIndex(reflect.ValueOf(name)) - case reflect.Struct: - props := structPropertiesOf(v.Type()) - // Ignore nonexistent properties. - if sf, ok := props[name]; ok { - return v.FieldByIndex(sf.Index) - } - return reflect.Value{} - default: - panic(fmt.Sprintf("property(%q): bad value %s of kind %s", name, v, v.Kind())) - } -} - -// properties returns an iterator over the names and values of all properties -// in v, which must be a map or a struct. -// If a struct, zero-valued properties that are marked omitempty or omitzero -// are excluded. -func properties(v reflect.Value) iter.Seq2[string, reflect.Value] { - return func(yield func(string, reflect.Value) bool) { - switch v.Kind() { - case reflect.Map: - for k, e := range v.Seq2() { - if !yield(k.String(), e) { - return - } - } - case reflect.Struct: - for name, sf := range structPropertiesOf(v.Type()) { - val := v.FieldByIndex(sf.Index) - if val.IsZero() { - info := util.FieldJSONInfo(sf) - if info.Settings["omitempty"] || info.Settings["omitzero"] { - continue - } - } - if !yield(name, val) { - return - } - } - default: - panic(fmt.Sprintf("bad value %s of kind %s", v, v.Kind())) - } - } -} - -// numPropertiesBounds returns bounds on the number of v's properties. -// v must be a map or a struct. -// If v is a map, both bounds are the map's size. -// If v is a struct, the max is the number of struct properties. -// But since we don't know whether a zero value indicates a missing optional property -// or not, be generous and use the number of non-zero properties as the min. -func numPropertiesBounds(v reflect.Value, isRequired map[string]bool) (int, int) { - switch v.Kind() { - case reflect.Map: - return v.Len(), v.Len() - case reflect.Struct: - sp := structPropertiesOf(v.Type()) - min := 0 - for prop, sf := range sp { - if !v.FieldByIndex(sf.Index).IsZero() || isRequired[prop] { - min++ - } - } - return min, len(sp) - default: - panic(fmt.Sprintf("properties: bad value: %s of kind %s", v, v.Kind())) - } -} - -// A propertyMap is a map from property name to struct field index. -type propertyMap = map[string]reflect.StructField - -var structProperties sync.Map // from reflect.Type to propertyMap - -// structPropertiesOf returns the JSON Schema properties for the struct type t. -// The caller must not mutate the result. -func structPropertiesOf(t reflect.Type) propertyMap { - // Mutex not necessary: at worst we'll recompute the same value. - if props, ok := structProperties.Load(t); ok { - return props.(propertyMap) - } - props := map[string]reflect.StructField{} - for _, sf := range reflect.VisibleFields(t) { - info := util.FieldJSONInfo(sf) - if !info.Omit { - props[info.Name] = sf - } - } - structProperties.Store(t, props) - return props -} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go deleted file mode 100644 index 7fb52d18..00000000 --- a/jsonschema/validate_test.go +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package jsonschema - -import ( - "encoding/json" - "fmt" - "net/url" - "os" - "path/filepath" - "reflect" - "strings" - "testing" -) - -// The test for validation uses the official test suite, expressed as a set of JSON files. -// Each file is an array of group objects. - -// A testGroup consists of a schema and some tests on it. -type testGroup struct { - Description string - Schema *Schema - Tests []test -} - -// A test consists of a JSON instance to be validated and the expected result. -type test struct { - Description string - Data any - Valid bool -} - -func TestValidate(t *testing.T) { - files, err := filepath.Glob(filepath.FromSlash("testdata/draft2020-12/*.json")) - if err != nil { - t.Fatal(err) - } - if len(files) == 0 { - t.Fatal("no files") - } - for _, file := range files { - base := filepath.Base(file) - t.Run(base, func(t *testing.T) { - data, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - var groups []testGroup - if err := json.Unmarshal(data, &groups); err != nil { - t.Fatal(err) - } - for _, g := range groups { - t.Run(g.Description, func(t *testing.T) { - rs, err := g.Schema.Resolve(&ResolveOptions{Loader: loadRemote}) - if err != nil { - t.Fatal(err) - } - for _, test := range g.Tests { - t.Run(test.Description, func(t *testing.T) { - err = rs.Validate(test.Data) - if err != nil && test.Valid { - t.Errorf("wanted success, but failed with: %v", err) - } - if err == nil && !test.Valid { - t.Error("succeeded but wanted failure") - } - if t.Failed() { - t.Errorf("schema: %s", g.Schema.json()) - t.Fatalf("instance: %v (%[1]T)", test.Data) - } - }) - } - }) - } - }) - } -} - -func TestValidateErrors(t *testing.T) { - schema := &Schema{ - PrefixItems: []*Schema{{Contains: &Schema{Type: "integer"}}}, - } - rs, err := schema.Resolve(nil) - if err != nil { - t.Fatal(err) - } - err = rs.Validate([]any{[]any{"1"}}) - want := "prefixItems/0" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("error:\n%s\ndoes not contain %q", err, want) - } -} - -func TestValidateDefaults(t *testing.T) { - s := &Schema{ - Properties: map[string]*Schema{ - "a": {Type: "integer", Default: mustMarshal(1)}, - "b": {Type: "string", Default: mustMarshal("s")}, - }, - Default: mustMarshal(map[string]any{"a": 1, "b": "two"}), - } - if _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}); err != nil { - t.Fatal(err) - } - - s = &Schema{ - Properties: map[string]*Schema{ - "a": {Type: "integer", Default: mustMarshal(3)}, - "b": {Type: "string", Default: mustMarshal("s")}, - }, - Default: mustMarshal(map[string]any{"a": 1, "b": 2}), - } - _, err := s.Resolve(&ResolveOptions{ValidateDefaults: true}) - want := `has type "integer", want "string"` - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("Resolve returned error %q, want %q", err, want) - } -} - -func TestApplyDefaults(t *testing.T) { - schema := &Schema{ - Properties: map[string]*Schema{ - "A": {Default: mustMarshal(1)}, - "B": {Default: mustMarshal(2)}, - "C": {Default: mustMarshal(3)}, - }, - Required: []string{"C"}, - } - rs, err := schema.Resolve(&ResolveOptions{ValidateDefaults: true}) - if err != nil { - t.Fatal(err) - } - - type S struct{ A, B, C int } - for _, tt := range []struct { - instancep any // pointer to instance value - want any // desired value (not a pointer) - }{ - { - &map[string]any{"B": 0}, - map[string]any{ - "A": float64(1), // filled from default - "B": 0, // untouched: it was already there - // "C" not added: it is required (Validate will catch that) - }, - }, - { - &S{B: 1}, - S{ - A: 1, // filled from default - B: 1, // untouched: non-zero - C: 0, // untouched: required - }, - }, - } { - if err := rs.ApplyDefaults(tt.instancep); err != nil { - t.Fatal(err) - } - got := reflect.ValueOf(tt.instancep).Elem().Interface() // dereference the pointer - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("\ngot %#v\nwant %#v", got, tt.want) - } - } -} - -func TestStructInstance(t *testing.T) { - instance := struct { - I int - B bool `json:"b"` - P *int // either missing or nil - u int // unexported: not a property - }{1, true, nil, 0} - - for _, tt := range []struct { - s Schema - want bool - }{ - { - Schema{MinProperties: Ptr(4)}, - false, - }, - { - Schema{MinProperties: Ptr(3)}, - true, // P interpreted as present - }, - { - Schema{MaxProperties: Ptr(1)}, - false, - }, - { - Schema{MaxProperties: Ptr(2)}, - true, // P interpreted as absent - }, - { - Schema{Required: []string{"i"}}, // the name is "I" - false, - }, - { - Schema{Required: []string{"B"}}, // the name is "b" - false, - }, - { - Schema{PropertyNames: &Schema{MinLength: Ptr(2)}}, - false, - }, - { - Schema{Properties: map[string]*Schema{"b": {Type: "boolean"}}}, - true, - }, - { - Schema{Properties: map[string]*Schema{"b": {Type: "number"}}}, - false, - }, - { - Schema{Required: []string{"I"}}, - true, - }, - { - Schema{Required: []string{"I", "P"}}, - true, // P interpreted as present - }, - { - Schema{Required: []string{"I", "P"}, Properties: map[string]*Schema{"P": {Type: "number"}}}, - false, // P interpreted as present, but not a number - }, - { - Schema{Required: []string{"I"}, Properties: map[string]*Schema{"P": {Type: "number"}}}, - true, // P not required, so interpreted as absent - }, - { - Schema{Required: []string{"I"}, AdditionalProperties: falseSchema()}, - false, - }, - { - Schema{DependentRequired: map[string][]string{"b": {"u"}}}, - false, - }, - { - Schema{DependentSchemas: map[string]*Schema{"b": falseSchema()}}, - false, - }, - { - Schema{UnevaluatedProperties: falseSchema()}, - false, - }, - } { - res, err := tt.s.Resolve(nil) - if err != nil { - t.Fatal(err) - } - err = res.Validate(instance) - if err == nil && !tt.want { - t.Errorf("succeeded unexpectedly\nschema = %s", tt.s.json()) - } else if err != nil && tt.want { - t.Errorf("Validate: %v\nschema = %s", err, tt.s.json()) - } - } -} - -func mustMarshal(x any) json.RawMessage { - data, err := json.Marshal(x) - if err != nil { - panic(err) - } - return json.RawMessage(data) -} - -// loadRemote loads a remote reference used in the test suite. -func loadRemote(uri *url.URL) (*Schema, error) { - // Anything with localhost:1234 refers to the remotes directory in the test suite repo. - if uri.Host == "localhost:1234" { - return loadSchemaFromFile(filepath.FromSlash(filepath.Join("testdata/remotes", uri.Path))) - } - // One test needs the meta-schema files. - const metaPrefix = "https://json-schema.org/draft/2020-12/" - if after, ok := strings.CutPrefix(uri.String(), metaPrefix); ok { - return loadSchemaFromFile(filepath.FromSlash("meta-schemas/draft2020-12/" + after + ".json")) - } - return nil, fmt.Errorf("don't know how to load %s", uri) -} - -func loadSchemaFromFile(filename string) (*Schema, error) { - data, err := os.ReadFile(filename) - if err != nil { - return nil, err - } - var s Schema - if err := json.Unmarshal(data, &s); err != nil { - return nil, fmt.Errorf("unmarshaling JSON at %s: %w", filename, err) - } - return &s, nil -} diff --git a/mcp/client.go b/mcp/client.go index 3a935040..dea3e854 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -6,36 +6,43 @@ package mcp import ( "context" + "encoding/json" + "fmt" "iter" "slices" "sync" "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // A Client is an MCP client, which may be connected to an MCP server // using the [Client.Connect] method. type Client struct { - name string - version string + impl *Implementation opts ClientOptions mu sync.Mutex roots *featureSet[*Root] sessions []*ClientSession - sendingMethodHandler_ MethodHandler[*ClientSession] - receivingMethodHandler_ MethodHandler[*ClientSession] + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler } -// NewClient creates a new Client. +// NewClient creates a new [Client]. // // Use [Client.Connect] to connect it to an MCP server. // +// The first argument must not be nil. +// // If non-nil, the provided options configure the Client. -func NewClient(name, version string, opts *ClientOptions) *Client { +func NewClient(impl *Implementation, opts *ClientOptions) *Client { + if impl == nil { + panic("nil Implementation") + } c := &Client{ - name: name, - version: version, + impl: impl, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], @@ -48,15 +55,23 @@ func NewClient(name, version string, opts *ClientOptions) *Client { // ClientOptions configures the behavior of the client. type ClientOptions struct { - // Handler for sampling. - // Called when a server calls CreateMessage. - CreateMessageHandler func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) + // CreateMessageHandler handles incoming requests for sampling/createMessage. + // + // Setting CreateMessageHandler to a non-nil value causes the client to + // advertise the sampling capability. + CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // ElicitationHandler handles incoming requests for elicitation/create. + // + // Setting ElicitationHandler to a non-nil value causes the client to + // advertise the elicitation capability. + ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) // Handlers for notifications from the server. - ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) - PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) - ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) - LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) - ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams) + ToolListChangedHandler func(context.Context, *ToolListChangedRequest) + PromptListChangedHandler func(context.Context, *PromptListChangedRequest) + ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest) + ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) + LoggingMessageHandler func(context.Context, *LoggingMessageRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. @@ -65,10 +80,11 @@ type ClientOptions struct { // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. -func (c *Client) bind(conn *jsonrpc2.Connection) *ClientSession { - cs := &ClientSession{ - conn: conn, - client: c, +func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { + assert(mcpConn != nil && conn != nil, "nil connection") + cs := &ClientSession{conn: conn, mcpConn: mcpConn, client: c, onClose: onClose} + if state != nil { + cs.state = *state } c.mu.Lock() defer c.mu.Unlock() @@ -86,40 +102,63 @@ func (c *Client) disconnect(cs *ClientSession) { }) } +// TODO: Consider exporting this type and its field. +type unsupportedProtocolVersionError struct { + version string +} + +func (e unsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.version) +} + +// ClientSessionOptions is reserved for future use. +type ClientSessionOptions struct{} + +func (c *Client) capabilities() *ClientCapabilities { + caps := &ClientCapabilities{} + caps.Roots.ListChanged = true + if c.opts.CreateMessageHandler != nil { + caps.Sampling = &SamplingCapabilities{} + } + if c.opts.ElicitationHandler != nil { + caps.Elicitation = &ElicitationCapabilities{} + } + return caps +} + // Connect begins an MCP session by connecting to a server over the given -// transport, and initializing the session. +// transport. The resulting session is initialized, and ready to use. // // Typically, it is the responsibility of the client to close the connection // when it is no longer needed. However, if the connection is closed by the // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. -func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, err error) { - cs, err = connect(ctx, t, c) +func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) { + cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) if err != nil { return nil, err } - caps := &ClientCapabilities{} - caps.Roots.ListChanged = true - if c.opts.CreateMessageHandler != nil { - caps.Sampling = &SamplingCapabilities{} - } - params := &InitializeParams{ - ClientInfo: &implementation{Name: c.name, Version: c.version}, - Capabilities: caps, - ProtocolVersion: "2025-03-26", - } - // TODO(rfindley): handle protocol negotiation gracefully. If the server - // responds with 2024-11-05, surface that failure to the caller of connect, - // so that they can choose a different transport. - res, err := handleSend[*InitializeResult](ctx, cs, methodInitialize, params) + ProtocolVersion: latestProtocolVersion, + ClientInfo: c.impl, + Capabilities: c.capabilities(), + } + req := &InitializeRequest{Session: cs, Params: params} + res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) if err != nil { _ = cs.Close() return nil, err } - cs.initializeResult = res - if err := handleNotify(ctx, cs, notificationInitialized, &InitializedParams{}); err != nil { + if !slices.Contains(supportedProtocolVersions, res.ProtocolVersion) { + return nil, unsupportedProtocolVersionError{res.ProtocolVersion} + } + cs.state.InitializeResult = res + if hc, ok := cs.mcpConn.(clientConnection); ok { + hc.sessionUpdated(cs.state) + } + req2 := &initializedClientRequest{Session: cs, Params: &InitializedParams{}} + if err := handleNotify(ctx, notificationInitialized, req2); err != nil { _ = cs.Close() return nil, err } @@ -135,25 +174,32 @@ func (c *Client) Connect(ctx context.Context, t Transport) (cs *ClientSession, e // methods can be used to send requests or notifications to the server. Create // a session by calling [Client.Connect]. // -// Call [ClientSession.Close] to close the connection, or await client -// termination with [ServerSession.Wait]. +// Call [ClientSession.Close] to close the connection, or await server +// termination with [ClientSession.Wait]. type ClientSession struct { - conn *jsonrpc2.Connection - client *Client - initializeResult *InitializeResult - keepaliveCancel context.CancelFunc - mcpConn Connection + onClose func() + + conn *jsonrpc2.Connection + client *Client + keepaliveCancel context.CancelFunc + mcpConn Connection + + // No mutex is (currently) required to guard the session state, because it is + // only set synchronously during Client.Connect. + state clientSessionState } -func (cs *ClientSession) setConn(c Connection) { - cs.mcpConn = c +type clientSessionState struct { + InitializeResult *InitializeResult } +func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult } + func (cs *ClientSession) ID() string { - if cs.mcpConn == nil { - return "" + if c, ok := cs.mcpConn.(hasSessionID); ok { + return c.SessionID() } - return cs.mcpConn.SessionID() + return "" } // Close performs a graceful close of the connection, preventing new requests @@ -168,7 +214,13 @@ func (cs *ClientSession) Close() error { if cs.keepaliveCancel != nil { cs.keepaliveCancel() } - return cs.conn.Close() + err := cs.conn.Close() + + if cs.onClose != nil { + cs.onClose() + } + + return err } // Wait waits for the connection to be closed by the server. @@ -190,23 +242,22 @@ func (c *Client) AddRoots(roots ...*Root) { if len(roots) == 0 { return } - c.changeAndNotify(notificationRootsListChanged, &RootsListChangedParams{}, + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, func() bool { c.roots.add(roots...); return true }) } // RemoveRoots removes the roots with the given URIs, // and notifies any connected servers if the list has changed. // It is not an error to remove a nonexistent root. -// TODO: notification func (c *Client) RemoveRoots(uris ...string) { - c.changeAndNotify(notificationRootsListChanged, &RootsListChangedParams{}, + changeAndNotify(c, notificationRootsListChanged, &RootsListChangedParams{}, func() bool { return c.roots.remove(uris...) }) } // changeAndNotify is called when a feature is added or removed. // It calls change, which should do the work and report whether a change actually occurred. // If there was a change, it notifies a snapshot of the sessions. -func (c *Client) changeAndNotify(notification string, params Params, change func() bool) { +func changeAndNotify[P Params](c *Client, notification string, params P, change func() bool) { var sessions []*ClientSession // Lock for the change, but not for the notification. c.mu.Lock() @@ -217,7 +268,7 @@ func (c *Client) changeAndNotify(notification string, params Params, change func notifySessions(sessions, notification, params) } -func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsParams) (*ListRootsResult, error) { +func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) { c.mu.Lock() defer c.mu.Unlock() roots := slices.Collect(c.roots.all()) @@ -229,12 +280,177 @@ func (c *Client) listRoots(_ context.Context, _ *ClientSession, _ *ListRootsPara }, nil } -func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *CreateMessageParams) (*CreateMessageResult, error) { +func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? - return nil, &jsonrpc2.WireError{Code: CodeUnsupportedMethod, Message: "client does not support CreateMessage"} + return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage") + } + return c.opts.CreateMessageHandler(ctx, req) +} + +func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + if c.opts.ElicitationHandler == nil { + // TODO: wrap or annotate this error? Pick a standard code? + return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support elicitation") + } + + // Validate that the requested schema only contains top-level properties without nesting + if err := validateElicitSchema(req.Params.RequestedSchema); err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, err.Error()) + } + + res, err := c.opts.ElicitationHandler(ctx, req) + if err != nil { + return nil, err + } + + // Validate elicitation result content against requested schema + if req.Params.RequestedSchema != nil && res.Content != nil { + // TODO: is this the correct behavior if validation fails? + // It isn't the *server's* params that are invalid, so why would we return + // this code to the server? + resolved, err := req.Params.RequestedSchema.Resolve(nil) + if err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + } + + if err := resolved.Validate(res.Content); err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + } + } + + return res, nil +} + +// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. +// Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only. +func validateElicitSchema(schema *jsonschema.Schema) error { + if schema == nil { + return nil // nil schema is allowed + } + + // The root schema must be of type "object" if specified + if schema.Type != "" && schema.Type != "object" { + return fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) + } + + // Check if the schema has properties + if schema.Properties != nil { + for propName, propSchema := range schema.Properties { + if propSchema == nil { + continue + } + + if err := validateElicitProperty(propName, propSchema); err != nil { + return err + } + } + } + + return nil +} + +// validateElicitProperty validates a single property in an elicitation schema. +func validateElicitProperty(propName string, propSchema *jsonschema.Schema) error { + // Check if this property has nested properties (not allowed) + if len(propSchema.Properties) > 0 { + return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName) + } + + // Validate based on the property type - only primitives are supported + switch propSchema.Type { + case "string": + return validateElicitStringProperty(propName, propSchema) + case "number", "integer": + return validateElicitNumberProperty(propName, propSchema) + case "boolean": + return validateElicitBooleanProperty(propName, propSchema) + default: + return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type) + } +} + +// validateElicitStringProperty validates string-type properties, including enums. +func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema) error { + // Handle enum validation (enums are a special case of strings) + if len(propSchema.Enum) > 0 { + // Enums must be string type (or untyped which defaults to string) + if propSchema.Type != "" && propSchema.Type != "string" { + return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type) + } + // Enum values themselves are validated by the JSON schema library + // Validate enumNames if present - must match enum length + if propSchema.Extra != nil { + if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { + // Type check enumNames - should be a slice + if enumNamesSlice, ok := enumNamesRaw.([]interface{}); ok { + if len(enumNamesSlice) != len(propSchema.Enum) { + return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice)) + } + } else { + return fmt.Errorf("elicit schema property %q has invalid enumNames type, must be an array", propName) + } + } + } + return nil + } + + // Validate format if specified - only specific formats are allowed + if propSchema.Format != "" { + allowedFormats := map[string]bool{ + "email": true, + "uri": true, + "date": true, + "date-time": true, + } + if !allowedFormats[propSchema.Format] { + return fmt.Errorf("elicit schema property %q has unsupported format %q, only email, uri, date, and date-time are allowed", propName, propSchema.Format) + } + } + + // Validate minLength constraint if specified + if propSchema.MinLength != nil { + if *propSchema.MinLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid minLength %d, must be non-negative", propName, *propSchema.MinLength) + } + } + + // Validate maxLength constraint if specified + if propSchema.MaxLength != nil { + if *propSchema.MaxLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid maxLength %d, must be non-negative", propName, *propSchema.MaxLength) + } + // Check that maxLength >= minLength if both are specified + if propSchema.MinLength != nil && *propSchema.MaxLength < *propSchema.MinLength { + return fmt.Errorf("elicit schema property %q has maxLength %d less than minLength %d", propName, *propSchema.MaxLength, *propSchema.MinLength) + } + } + + return nil +} + +// validateElicitNumberProperty validates number and integer-type properties. +func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema) error { + if propSchema.Minimum != nil && propSchema.Maximum != nil { + if *propSchema.Maximum < *propSchema.Minimum { + return fmt.Errorf("elicit schema property %q has maximum %g less than minimum %g", propName, *propSchema.Maximum, *propSchema.Minimum) + } + } + + return nil +} + +// validateElicitBooleanProperty validates boolean-type properties. +func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { + // Validate default value if specified - must be a valid boolean + if propSchema.Default != nil { + var defaultValue bool + if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil { + return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err) + } } - return c.opts.CreateMessageHandler(ctx, cs, params) + + return nil } // AddSendingMiddleware wraps the current sending method handler using the provided @@ -246,7 +462,7 @@ func (c *Client) createMessage(ctx context.Context, cs *ClientSession, params *C // // Sending middleware is called when a request is sent. It is useful for tasks // such as tracing, metrics, and adding progress tokens. -func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) { +func (c *Client) AddSendingMiddleware(middleware ...Middleware) { c.mu.Lock() defer c.mu.Unlock() addMiddleware(&c.sendingMethodHandler_, middleware) @@ -261,23 +477,30 @@ func (c *Client) AddSendingMiddleware(middleware ...Middleware[*ClientSession]) // // Receiving middleware is called when a request is received. It is useful for tasks // such as authentication, request logging and metrics. -func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]) { +func (c *Client) AddReceivingMiddleware(middleware ...Middleware) { c.mu.Lock() defer c.mu.Unlock() addMiddleware(&c.receivingMethodHandler_, middleware) } // clientMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. var clientMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete)), - methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)), - methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)), - methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)), - notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), - notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), - notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), - notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), - notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)), + methodComplete: newClientMethodInfo(clientSessionMethod((*ClientSession).Complete), 0), + methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), + methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), + methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK), + notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), + notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), + notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), + notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), + notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), + notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), + notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -288,52 +511,69 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (cs *ClientSession) handle(ctx context.Context, req *JSONRPCRequest) (any, error) { +func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + if req.IsCall() { + jsonrpc2.Async(ctx) + } return handleReceive(ctx, cs, req) } -func (cs *ClientSession) sendingMethodHandler() methodHandler { +func (cs *ClientSession) sendingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.sendingMethodHandler_ } -func (cs *ClientSession) receivingMethodHandler() methodHandler { +func (cs *ClientSession) receivingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.receivingMethodHandler_ } -// getConn implements [session.getConn]. +// getConn implements [Session.getConn]. func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn } func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { return &emptyResult{}, nil } +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + +func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] { + return &ClientRequest[P]{Session: cs, Params: params} +} + // Ping makes an MCP "ping" request to the server. func (cs *ClientSession) Ping(ctx context.Context, params *PingParams) error { - _, err := handleSend[*emptyResult](ctx, cs, methodPing, orZero[Params](params)) + _, err := handleSend[*emptyResult](ctx, methodPing, newClientRequest(cs, orZero[Params](params))) return err } // ListPrompts lists prompts that are currently available on the server. func (cs *ClientSession) ListPrompts(ctx context.Context, params *ListPromptsParams) (*ListPromptsResult, error) { - return handleSend[*ListPromptsResult](ctx, cs, methodListPrompts, orZero[Params](params)) + return handleSend[*ListPromptsResult](ctx, methodListPrompts, newClientRequest(cs, orZero[Params](params))) } // GetPrompt gets a prompt from the server. func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) (*GetPromptResult, error) { - return handleSend[*GetPromptResult](ctx, cs, methodGetPrompt, orZero[Params](params)) + return handleSend[*GetPromptResult](ctx, methodGetPrompt, newClientRequest(cs, orZero[Params](params))) } // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return handleSend[*ListToolsResult](ctx, cs, methodListTools, orZero[Params](params)) + return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) } -// CallTool calls the tool with the given name and arguments. -// The arguments can be any value that marshals into a JSON object. +// CallTool calls the tool with the given parameters. +// +// The params.Arguments can be any value that marshals into a JSON object. func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) { if params == nil { params = new(CallToolParams) @@ -342,54 +582,87 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } - return handleSend[*CallToolResult](ctx, cs, methodCallTool, params) + return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } -func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error { - _, err := handleSend[*emptyResult](ctx, cs, methodSetLevel, orZero[Params](params)) +func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error { + _, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params))) return err } // ListResources lists the resources that are currently available on the server. func (cs *ClientSession) ListResources(ctx context.Context, params *ListResourcesParams) (*ListResourcesResult, error) { - return handleSend[*ListResourcesResult](ctx, cs, methodListResources, orZero[Params](params)) + return handleSend[*ListResourcesResult](ctx, methodListResources, newClientRequest(cs, orZero[Params](params))) } // ListResourceTemplates lists the resource templates that are currently available on the server. func (cs *ClientSession) ListResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { - return handleSend[*ListResourceTemplatesResult](ctx, cs, methodListResourceTemplates, orZero[Params](params)) + return handleSend[*ListResourceTemplatesResult](ctx, methodListResourceTemplates, newClientRequest(cs, orZero[Params](params))) } // ReadResource asks the server to read a resource and return its contents. func (cs *ClientSession) ReadResource(ctx context.Context, params *ReadResourceParams) (*ReadResourceResult, error) { - return handleSend[*ReadResourceResult](ctx, cs, methodReadResource, orZero[Params](params)) + return handleSend[*ReadResourceResult](ctx, methodReadResource, newClientRequest(cs, orZero[Params](params))) } func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (*CompleteResult, error) { - return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params)) + return handleSend[*CompleteResult](ctx, methodComplete, newClientRequest(cs, orZero[Params](params))) +} + +// Subscribe sends a "resources/subscribe" request to the server, asking for +// notifications when the specified resource changes. +func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) + return err } -func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) +// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling +// a previous subscription. +func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) + return err } -func (c *Client) callPromptChangedHandler(ctx context.Context, s *ClientSession, params *PromptListChangedParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.PromptListChangedHandler, s, params) +func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { + if h := c.opts.ToolListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSession, params *ResourceListChangedParams) (Result, error) { - return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) +func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) { + if h := c.opts.PromptListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) { +func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) { + if h := c.opts.ResourceListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) { + if h := c.opts.ResourceUpdatedHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + +func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { - h(ctx, cs, params) + h(ctx, req) } return nil, nil } func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { - return callNotificationHandler(ctx, cs.client.opts.ProgressNotificationHandler, cs, params) + if h := cs.client.opts.ProgressNotificationHandler; h != nil { + h(ctx, clientRequestFor(cs, params)) + } + return nil, nil } // NotifyProgress sends a progress notification from the client to the server @@ -397,7 +670,7 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa // This can be used if the client is performing a long-running task that was // initiated by the server func (cs *ClientSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { - return handleNotify(ctx, cs, notificationProgress, params) + return handleNotify(ctx, notificationProgress, newClientRequest(cs, orZero[Params](params))) } // Tools provides an iterator for all tools available on the server, diff --git a/mcp/client_example_test.go b/mcp/client_example_test.go new file mode 100644 index 00000000..bba3da44 --- /dev/null +++ b/mcp/client_example_test.go @@ -0,0 +1,139 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+roots + +func Example_roots() { + ctx := context.Background() + + // Create a client with a single root. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + c.AddRoots(&mcp.Root{URI: "file://a"}) + + // Now create a server with a handler to receive notifications about roots. + rootsChanged := make(chan struct{}) + handleRootsChanged := func(ctx context.Context, req *mcp.RootsListChangedRequest) { + rootList, err := req.Session.ListRoots(ctx, nil) + if err != nil { + log.Fatal(err) + } + var roots []string + for _, root := range rootList.Roots { + roots = append(roots, root.URI) + } + fmt.Println(roots) + close(rootsChanged) + } + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + RootsListChangedHandler: handleRootsChanged, + }) + + // Connect the server and client... + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + clientSession, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer clientSession.Close() + + // ...and add a root. The server is notified about the change. + c.AddRoots(&mcp.Root{URI: "file://b"}) + <-rootsChanged + // Output: [file://a file://b] +} + +// !-roots + +// !+sampling + +func Example_sampling() { + ctx := context.Background() + + // Create a client with a sampling handler. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + return &mcp.CreateMessageResult{ + Content: &mcp.TextContent{ + Text: "would have created a message", + }, + }, nil + }, + }) + + // Connect the server and client... + ct, st := mcp.NewInMemoryTransports() + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + session, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + + msg, err := session.CreateMessage(ctx, &mcp.CreateMessageParams{}) + if err != nil { + log.Fatal(err) + } + fmt.Println(msg.Content.(*mcp.TextContent).Text) + // Output: would have created a message +} + +// !-sampling + +// !+elicitation + +func Example_elicitation() { + ctx := context.Background() + ct, st := mcp.NewInMemoryTransports() + + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + c := mcp.NewClient(testImpl, &mcp.ClientOptions{ + ElicitationHandler: func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + if _, err := c.Connect(ctx, ct, nil); err != nil { + log.Fatal(err) + } + res, err := ss.Elicit(ctx, &mcp.ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + if err != nil { + log.Fatal(err) + } + fmt.Println(res.Content["test"]) + // Output: value +} + +// !-elicitation diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 7e6da95a..0183a733 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -7,27 +7,43 @@ package mcp_test import ( "context" "iter" + "log" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) func TestList(t *testing.T) { ctx := context.Background() - clientSession, serverSession, server := createSessions(ctx) - defer clientSession.Close() + server := mcp.NewServer(testImpl, nil) + client := mcp.NewClient(testImpl, nil) + serverTransport, clientTransport := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + log.Fatal(err) + } defer serverSession.Close() + clientSession, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + log.Fatal(err) + } + defer clientSession.Close() t.Run("tools", func(t *testing.T) { - toolA := mcp.NewServerTool("apple", "apple tool", SayHi) - toolB := mcp.NewServerTool("banana", "banana tool", SayHi) - toolC := mcp.NewServerTool("cherry", "cherry tool", SayHi) - tools := []*mcp.ServerTool{toolA, toolB, toolC} - wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} - server.AddTools(tools...) + var wantTools []*mcp.Tool + for _, name := range []string{"apple", "banana", "cherry"} { + tt := &mcp.Tool{Name: name, Description: name + " tool"} + mcp.AddTool(server, tt, SayHi) + is, err := jsonschema.For[SayHiParams](nil) + if err != nil { + t.Fatal(err) + } + tt.InputSchema = is + wantTools = append(wantTools, tt) + } t.Run("list", func(t *testing.T) { res, err := clientSession.ListTools(ctx, nil) if err != nil { @@ -38,17 +54,18 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools) + testIterator(t, clientSession.Tools(ctx, nil), wantTools) }) }) t.Run("resources", func(t *testing.T) { - resourceA := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://apple"}} - resourceB := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://banana"}} - resourceC := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://cherry"}} - wantResources := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource} - resources := []*mcp.ServerResource{resourceA, resourceB, resourceC} - server.AddResources(resources...) + var wantResources []*mcp.Resource + for _, name := range []string{"apple", "banana", "cherry"} { + r := &mcp.Resource{URI: "http://" + name} + wantResources = append(wantResources, r) + server.AddResource(r, nil) + } + t.Run("list", func(t *testing.T) { res, err := clientSession.ListResources(ctx, nil) if err != nil { @@ -59,20 +76,17 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources) + testIterator(t, clientSession.Resources(ctx, nil), wantResources) }) }) t.Run("templates", func(t *testing.T) { - resourceTmplA := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://apple/{x}"}} - resourceTmplB := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://banana/{x}"}} - resourceTmplC := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://cherry/{x}"}} - wantResourceTemplates := []*mcp.ResourceTemplate{ - resourceTmplA.ResourceTemplate, resourceTmplB.ResourceTemplate, - resourceTmplC.ResourceTemplate, + var wantResourceTemplates []*mcp.ResourceTemplate + for _, name := range []string{"apple", "banana", "cherry"} { + rt := &mcp.ResourceTemplate{URITemplate: "http://" + name + "/{x}"} + wantResourceTemplates = append(wantResourceTemplates, rt) + server.AddResourceTemplate(rt, nil) } - resourceTemplates := []*mcp.ServerResourceTemplate{resourceTmplA, resourceTmplB, resourceTmplC} - server.AddResourceTemplates(resourceTemplates...) t.Run("list", func(t *testing.T) { res, err := clientSession.ListResourceTemplates(ctx, nil) if err != nil { @@ -83,17 +97,17 @@ func TestList(t *testing.T) { } }) t.Run("ResourceTemplatesIterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates) + testIterator(t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates) }) }) t.Run("prompts", func(t *testing.T) { - promptA := newServerPrompt("apple", "apple prompt") - promptB := newServerPrompt("banana", "banana prompt") - promptC := newServerPrompt("cherry", "cherry prompt") - wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} - prompts := []*mcp.ServerPrompt{promptA, promptB, promptC} - server.AddPrompts(prompts...) + var wantPrompts []*mcp.Prompt + for _, name := range []string{"apple", "banana", "cherry"} { + p := &mcp.Prompt{Name: name, Description: name + " prompt"} + wantPrompts = append(wantPrompts, p) + server.AddPrompt(p, testPromptHandler) + } t.Run("list", func(t *testing.T) { res, err := clientSession.ListPrompts(ctx, nil) if err != nil { @@ -104,12 +118,12 @@ func TestList(t *testing.T) { } }) t.Run("iterator", func(t *testing.T) { - testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts) + testIterator(t, clientSession.Prompts(ctx, nil), wantPrompts) }) }) } -func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) { +func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { t.Helper() var got []*T for x, err := range seq { @@ -123,14 +137,6 @@ func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, er } } -// testPromptHandler is used for type inference newServerPrompt. -func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { +func testPromptHandler(context.Context, *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { panic("not implemented") } - -func newServerPrompt(name, desc string) *mcp.ServerPrompt { - return &mcp.ServerPrompt{ - Prompt: &mcp.Prompt{Name: name, Description: desc}, - Handler: testPromptHandler, - } -} diff --git a/mcp/client_test.go b/mcp/client_test.go index 73fe09e6..eaeedc81 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -11,7 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) type Item struct { @@ -190,3 +190,48 @@ func TestClientPaginateVariousPageSizes(t *testing.T) { }) } } + +func TestClientCapabilities(t *testing.T) { + testCases := []struct { + name string + configureClient func(s *Client) + clientOpts ClientOptions + wantCapabilities *ClientCapabilities + }{ + { + name: "With initial capabilities", + configureClient: func(s *Client) {}, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + }, + }, + { + name: "With sampling", + configureClient: func(s *Client) {}, + clientOpts: ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + Sampling: &SamplingCapabilities{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := NewClient(testImpl, &tc.clientOpts) + tc.configureClient(client) + gotCapabilities := client.capabilities() + if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { + t.Errorf("capabilities() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/mcp/cmd.go b/mcp/cmd.go index 163bb0ca..b531eaf1 100644 --- a/mcp/cmd.go +++ b/mcp/cmd.go @@ -13,44 +13,46 @@ import ( "time" ) +var defaultTerminateDuration = 5 * time.Second // mutable for testing + // A CommandTransport is a [Transport] that runs a command and communicates // with it over stdin/stdout, using newline-delimited JSON. type CommandTransport struct { - cmd *exec.Cmd -} - -// NewCommandTransport returns a [CommandTransport] that runs the given command -// and communicates with it over stdin/stdout. -// -// The resulting transport takes ownership of the command, starting it during -// [CommandTransport.Connect], and stopping it when the connection is closed. -func NewCommandTransport(cmd *exec.Cmd) *CommandTransport { - return &CommandTransport{cmd} + Command *exec.Cmd + // TerminateDuration controls how long Close waits after closing stdin + // for the process to exit before sending SIGTERM. + // If zero or negative, the default of 5s is used. + TerminateDuration time.Duration } // Connect starts the command, and connects to it over stdin/stdout. func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { - stdout, err := t.cmd.StdoutPipe() + stdout, err := t.Command.StdoutPipe() if err != nil { return nil, err } stdout = io.NopCloser(stdout) // close the connection by closing stdin, not stdout - stdin, err := t.cmd.StdinPipe() + stdin, err := t.Command.StdinPipe() if err != nil { return nil, err } - if err := t.cmd.Start(); err != nil { + if err := t.Command.Start(); err != nil { return nil, err } - return newIOConn(&pipeRWC{t.cmd, stdout, stdin}), nil + td := t.TerminateDuration + if td <= 0 { + td = defaultTerminateDuration + } + return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil } // A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over // stdin/stdout pipes. type pipeRWC struct { - cmd *exec.Cmd - stdout io.ReadCloser - stdin io.WriteCloser + cmd *exec.Cmd + stdout io.ReadCloser + stdin io.WriteCloser + terminateDuration time.Duration } func (s *pipeRWC) Read(p []byte) (n int, err error) { @@ -81,7 +83,7 @@ func (s *pipeRWC) Close() error { select { case err := <-resChan: return err, true - case <-time.After(5 * time.Second): + case <-time.After(s.terminateDuration): } return nil, false } diff --git a/mcp/cmd_export_test.go b/mcp/cmd_export_test.go new file mode 100644 index 00000000..331e8bd0 --- /dev/null +++ b/mcp/cmd_export_test.go @@ -0,0 +1,20 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import "time" + +// This file exports some helpers for mutating internals of the command +// transport for testing. + +// SetDefaultTerminateDuration sets the default command terminate duration, +// and returns a function to reset it to the default. +func SetDefaultTerminateDuration(d time.Duration) (reset func()) { + initial := defaultTerminateDuration + defaultTerminateDuration = d + return func() { + defaultTerminateDuration = initial + } +} diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index f66423d6..cbaadcb0 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -6,11 +6,15 @@ package mcp_test import ( "context" + "errors" "log" "os" "os/exec" + "os/signal" "runtime" + "syscall" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -18,66 +22,325 @@ import ( const runAsServer = "_MCP_RUN_AS_SERVER" +type SayHiParams struct { + Name string `json:"name"` +} + +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Hi " + args.Name}, + }, + }, nil, nil +} + func TestMain(m *testing.M) { - if os.Getenv(runAsServer) != "" { + // If the runAsServer variable is set, execute the relevant serverFunc + // instead of running tests (aka the fork and exec trick). + if name := os.Getenv(runAsServer); name != "" { + run := serverFuncs[name] + if run == nil { + log.Fatalf("Unknown server %q", name) + } os.Unsetenv(runAsServer) - runServer() + run() return } os.Exit(m.Run()) } +// serverFuncs defines server functions that may be run as subprocesses via +// [TestMain]. +var serverFuncs = map[string]func(){ + "default": runServer, + "cancelContext": runCancelContextServer, +} + func runServer() { ctx := context.Background() - server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) + server := mcp.NewServer(testImpl, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + if err := server.Run(ctx, &mcp.StdioTransport{}); err != nil { + log.Fatal(err) + } +} - if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { +func runCancelContextServer() { + ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT) + defer done() + + server := mcp.NewServer(testImpl, nil) + if err := server.Run(ctx, &mcp.StdioTransport{}); err != nil { log.Fatal(err) } } -func TestCmdTransport(t *testing.T) { - // Conservatively, limit to major OS where we know that os.Exec is - // supported. - switch runtime.GOOS { - case "darwin", "linux", "windows": - default: - t.Skip("unsupported OS") +func TestServerRunContextCancel(t *testing.T) { + server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + + // run the server and capture the exit error + onServerExit := make(chan error) + go func() { + onServerExit <- server.Run(ctx, serverTransport) + }() + + // send a ping to the server to ensure it's running + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + session, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatal(err) + } + if err := session.Ping(context.Background(), nil); err != nil { + t.Fatal(err) } + // cancel the context to stop the server + cancel() + + // wait for the server to exit + // TODO: use synctest when availble + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after context cancellation") + case err := <-onServerExit: + if !errors.Is(err, context.Canceled) { + t.Fatalf("server did not exit after context cancellation, got error: %v", err) + } + } +} + +func TestServerInterrupt(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } + requireExec(t) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - exe, err := os.Executable() + cmd := createServerCommand(t, "default") + + client := mcp.NewClient(testImpl, nil) + _, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) if err != nil { t.Fatal(err) } - cmd := exec.Command(exe) - cmd.Env = append(os.Environ(), runAsServer+"=true") - client := mcp.NewClient("client", "v0.0.1", nil) - session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) + // get a signal when the server process exits + onExit := make(chan struct{}) + go func() { + cmd.Process.Wait() + close(onExit) + }() + + // send a signal to the server process to terminate it + if err := cmd.Process.Signal(os.Interrupt); err != nil { + t.Fatal(err) + } + + // wait for the server to exit + // TODO: use synctest when available + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after SIGINT") + case <-onExit: + } +} + +func TestStdioContextCancellation(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } + requireExec(t) + + // This test is a variant of TestServerInterrupt reproducing the conditions + // of #224, where interrupt failed to shut down the server because reads of + // Stdin were not unblocked. + + cmd := createServerCommand(t, "cancelContext") + // Creating a stdin pipe causes os.Stdin.Close to not immediately unblock + // pending reads. + _, _ = cmd.StdinPipe() + + // Just Start the command, rather than connecting to the server, because we + // don't want the client connection to indirectly flush stdin through writes. + if err := cmd.Start(); err != nil { + t.Fatalf("starting command: %v", err) + } + + // Sleep to make it more likely that the server is blocked in the read loop. + // + // This sleep isn't necessary for the test to pass, but *was* necessary for + // it to fail, before closing was fixed. Unfortunately, it is too invasive a + // change to have the jsonrpc2 package signal across packages when it is + // actually blocked in its read loop. + time.Sleep(100 * time.Millisecond) + + onExit := make(chan struct{}) + go func() { + cmd.Process.Wait() + close(onExit) + }() + + if err := cmd.Process.Signal(os.Interrupt); err != nil { + t.Fatal(err) + } + + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after SIGINT") + case <-onExit: + t.Logf("done.") + } +} + +func TestCmdTransport(t *testing.T) { + requireExec(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cmd := createServerCommand(t, "default") + + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + session, err := client.Connect(ctx, &mcp.CommandTransport{Command: cmd}, nil) if err != nil { - log.Fatal(err) + t.Fatal(err) } got, err := session.CallTool(ctx, &mcp.CallToolParams{ Name: "greet", Arguments: map[string]any{"name": "user"}, }) if err != nil { - log.Fatal(err) + t.Fatal(err) } want := &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi user"}, }, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) } if err := session.Close(); err != nil { t.Fatalf("closing server: %v", err) } } + +// createServerCommand creates a command to fork and exec the test binary as an +// MCP server. +// +// serverName must refer to an entry in the [serverFuncs] map. +func createServerCommand(t *testing.T, serverName string) *exec.Cmd { + t.Helper() + + exe, err := os.Executable() + if err != nil { + t.Fatal(err) + } + cmd := exec.Command(exe) + cmd.Env = append(os.Environ(), runAsServer+"="+serverName) + + return cmd +} + +func TestCommandTransportTerminateDuration(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("requires POSIX signals") + } + requireExec(t) + + // Unfortunately, since it does I/O, this test needs to rely on timing (we + // can't use synctest). However, we can still decreate the default + // termination duration to speed up the test. + const defaultDur = 50 * time.Millisecond + defer mcp.SetDefaultTerminateDuration(defaultDur)() + + tests := []struct { + name string + duration time.Duration + wantMinDuration time.Duration + wantMaxDuration time.Duration + }{ + { + name: "default duration (zero)", + duration: 0, + wantMinDuration: defaultDur, + wantMaxDuration: 1 * time.Second, // default + buffer + }, + { + name: "below minimum duration", + duration: -500 * time.Millisecond, + wantMinDuration: defaultDur, + wantMaxDuration: 1 * time.Second, // should use default + buffer + }, + { + name: "custom valid duration", + duration: 200 * time.Millisecond, + wantMinDuration: 200 * time.Millisecond, + wantMaxDuration: 1 * time.Second, // custom + buffer + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Use a command that won't exit when stdin is closed + cmd := exec.Command("sleep", "20") + transport := &mcp.CommandTransport{ + Command: cmd, + TerminateDuration: tt.duration, + } + + conn, err := transport.Connect(ctx) + if err != nil { + t.Fatal(err) + } + + start := time.Now() + err = conn.Close() + elapsed := time.Since(start) + + if err != nil { + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("Close() failed with unexpected error: %v", err) + } + } + if elapsed < tt.wantMinDuration { + t.Errorf("Close() took %v, expected at least %v", elapsed, tt.wantMinDuration) + } + if elapsed > tt.wantMaxDuration { + t.Errorf("Close() took %v, expected at most %v", elapsed, tt.wantMaxDuration) + } + + // Ensure the process was actually terminated + if cmd.Process != nil { + cmd.Process.Kill() + } + }) + } +} + +func requireExec(t *testing.T) { + t.Helper() + + // Conservatively, limit to major OS where we know that os.Exec is + // supported. + switch runtime.GOOS { + case "darwin", "linux", "windows": + default: + t.Skip("unsupported OS") + } +} + +var testImpl = &mcp.Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 3be54a3e..3393efcb 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -8,6 +8,7 @@ package mcp import ( "bytes" + "context" "encoding/json" "errors" "flag" @@ -19,10 +20,13 @@ import ( "strings" "testing" "testing/synctest" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" "golang.org/x/tools/txtar" ) @@ -46,12 +50,12 @@ var update = flag.Bool("update", false, "if set, update conformance test data") // with -update to have the test runner update the expected output, which may // be client or server depending on the perspective of the test. type conformanceTest struct { - name string // test name - path string // path to test file - archive *txtar.Archive // raw archive, for updating - tools, prompts, resources []string // named features to include - client []JSONRPCMessage // client messages - server []JSONRPCMessage // server messages + name string // test name + path string // path to test file + archive *txtar.Archive // raw archive, for updating + tools, prompts, resources []string // named features to include + client []jsonrpc.Message // client messages + server []jsonrpc.Message // server messages } // TODO(rfindley): add client conformance tests. @@ -95,20 +99,93 @@ func TestServerConformance(t *testing.T) { } } +type structuredInput struct { + In string `jsonschema:"the input"` +} + +type structuredOutput struct { + Out string `jsonschema:"the output"` +} + +func structuredTool(ctx context.Context, req *CallToolRequest, args *structuredInput) (*CallToolResult, *structuredOutput, error) { + return nil, &structuredOutput{"Ack " + args.In}, nil +} + +type tomorrowInput struct { + Now time.Time +} + +type tomorrowOutput struct { + Tomorrow time.Time +} + +func tomorrowTool(ctx context.Context, req *CallToolRequest, args tomorrowInput) (*CallToolResult, tomorrowOutput, error) { + return nil, tomorrowOutput{args.Now.Add(24 * time.Hour)}, nil +} + +type incInput struct { + X int `json:"x,omitempty"` +} + +type incOutput struct { + Y int `json:"y"` +} + +func incTool(_ context.Context, _ *CallToolRequest, args incInput) (*CallToolResult, incOutput, error) { + return nil, incOutput{args.X + 1}, nil +} + // runServerTest runs the server conformance test. // It must be executed in a synctest bubble. func runServerTest(t *testing.T, test *conformanceTest) { ctx := t.Context() // Construct the server based on features listed in the test. - s := NewServer("testServer", "v1.0.0", nil) - add(tools, s.AddTools, test.tools...) - add(prompts, s.AddPrompts, test.prompts...) - add(resources, s.AddResources, test.resources...) + s := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + for _, tn := range test.tools { + switch tn { + case "greet": + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + case "structured": + AddTool(s, &Tool{Name: "structured"}, structuredTool) + case "tomorrow": + AddTool(s, &Tool{Name: "tomorrow"}, tomorrowTool) + case "inc": + inSchema, err := jsonschema.For[incInput](nil) + if err != nil { + t.Fatal(err) + } + inSchema.Properties["x"].Default = json.RawMessage(`6`) + AddTool(s, &Tool{Name: "inc", InputSchema: inSchema}, incTool) + default: + t.Fatalf("unknown tool %q", tn) + } + } + for _, pn := range test.prompts { + switch pn { + case "code_review": + s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) + default: + t.Fatalf("unknown prompt %q", pn) + } + } + for _, rn := range test.resources { + switch rn { + case "info.txt": + s.AddResource(resource1, readHandler) + case "info": + s.AddResource(resource3, handleEmbeddedResource) + default: + t.Fatalf("unknown resource %q", rn) + } + } // Connect the server, and connect the client stream, // but don't connect an actual client. cTransport, sTransport := NewInMemoryTransports() - ss, err := s.Connect(ctx, sTransport) + ss, err := s.Connect(ctx, sTransport, nil) if err != nil { t.Fatal(err) } @@ -117,24 +194,24 @@ func runServerTest(t *testing.T, test *conformanceTest) { t.Fatal(err) } - writeMsg := func(msg JSONRPCMessage) { + writeMsg := func(msg jsonrpc.Message) { if err := cStream.Write(ctx, msg); err != nil { t.Fatalf("Write failed: %v", err) } } var ( - serverMessages []JSONRPCMessage - outRequests []*JSONRPCRequest - outResponses []*JSONRPCResponse + serverMessages []jsonrpc.Message + outRequests []*jsonrpc.Request + outResponses []*jsonrpc.Response ) // Separate client requests and responses; we use them differently. for _, msg := range test.client { switch msg := msg.(type) { - case *JSONRPCRequest: + case *jsonrpc.Request: outRequests = append(outRequests, msg) - case *JSONRPCResponse: + case *jsonrpc.Response: outResponses = append(outResponses, msg) default: t.Fatalf("bad message type %T", msg) @@ -143,7 +220,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // nextResponse handles incoming requests and notifications, and returns the // next incoming response. - nextResponse := func() (*JSONRPCResponse, error, bool) { + nextResponse := func() (*jsonrpc.Response, error, bool) { for { msg, err := cStream.Read(ctx) if err != nil { @@ -156,7 +233,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { return nil, err, false } serverMessages = append(serverMessages, msg) - if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() { // Pair up the next outgoing response with this request. // We assume requests arrive in the same order every time. if len(outResponses) == 0 { @@ -167,15 +244,15 @@ func runServerTest(t *testing.T, test *conformanceTest) { outResponses = outResponses[1:] continue } - return msg.(*JSONRPCResponse), nil, true + return msg.(*jsonrpc.Response), nil, true } } // Synthetic peer interacts with real peer. for _, req := range outRequests { writeMsg(req) - if req.ID.IsValid() { - // A request (as opposed to a notification). Wait for the response. + if req.IsCall() { + // A call (as opposed to a notification). Wait for the response. res, err, ok := nextResponse() if err != nil { t.Fatalf("reading server messages failed: %v", err) @@ -191,7 +268,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { // There might be more notifications or requests, but there shouldn't be more // responses. // Run this in a goroutine so the current thread can wait for it. - var extra *JSONRPCResponse + var extra *jsonrpc.Response go func() { extra, err, _ = nextResponse() }() @@ -240,8 +317,8 @@ func runServerTest(t *testing.T, test *conformanceTest) { t.Fatalf("os.WriteFile(%q) failed: %v", test.path, err) } } else { - // JSONRPCMessages are not comparable, so we instead compare lines of JSON. - transform := cmpopts.AcyclicTransformer("toJSON", func(msg JSONRPCMessage) []string { + // jsonrpc.Messages are not comparable, so we instead compare lines of JSON. + transform := cmpopts.AcyclicTransformer("toJSON", func(msg jsonrpc.Message) []string { encoded, err := jsonrpc2.EncodeIndent(msg, "", "\t") if err != nil { t.Fatal(err) @@ -271,9 +348,9 @@ func loadConformanceTest(dir, path string) (*conformanceTest, error) { } // decodeMessages loads JSON-RPC messages from the archive file. - decodeMessages := func(data []byte) ([]JSONRPCMessage, error) { + decodeMessages := func(data []byte) ([]jsonrpc.Message, error) { dec := json.NewDecoder(bytes.NewReader(data)) - var res []JSONRPCMessage + var res []jsonrpc.Message for dec.More() { var raw json.RawMessage if err := dec.Decode(&raw); err != nil { diff --git a/mcp/content.go b/mcp/content.go index ed7f6f99..108b0271 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -2,6 +2,9 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// TODO(findleyr): update JSON marshalling of all content types to preserve required fields. +// (See [TextContent.MarshalJSON], which handles this for text content). + package mcp import ( @@ -25,12 +28,19 @@ type TextContent struct { } func (c *TextContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure the required "text" field is always included, even when empty. + wire := struct { + Type string `json:"type"` + Text string `json:"text"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ Type: "text", Text: c.Text, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *TextContent) fromWire(wire *wireContent) { @@ -48,13 +58,19 @@ type ImageContent struct { } func (c *ImageContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := imageAudioWire{ Type: "image", MIMEType: c.MIMEType, - Data: c.Data, + Data: data, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *ImageContent) fromWire(wire *wireContent) { @@ -73,13 +89,19 @@ type AudioContent struct { } func (c AudioContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure required fields are always included, even when empty. + data := c.Data + if data == nil { + data = []byte{} + } + wire := imageAudioWire{ Type: "audio", MIMEType: c.MIMEType, - Data: c.Data, + Data: data, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *AudioContent) fromWire(wire *wireContent) { @@ -89,6 +111,15 @@ func (c *AudioContent) fromWire(wire *wireContent) { c.Annotations = wire.Annotations } +// Custom wire format to ensure required fields are always included, even when empty. +type imageAudioWire struct { + Type string `json:"type"` + MIMEType string `json:"mimeType"` + Data []byte `json:"data"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` +} + // ResourceLink is a link to a resource type ResourceLink struct { URI string @@ -177,10 +208,12 @@ func (r ResourceContents) MarshalJSON() ([]byte, error) { URI string `json:"uri,omitempty"` MIMEType string `json:"mimeType,omitempty"` Blob []byte `json:"blob"` + Meta Meta `json:"_meta,omitempty"` }{ URI: r.URI, MIMEType: r.MIMEType, Blob: r.Blob, + Meta: r.Meta, } return json.Marshal(br) } @@ -219,6 +252,9 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e } func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { + if wire == nil { + return nil, fmt.Errorf("nil content") + } if allow != nil && !allow[wire.Type] { return nil, fmt.Errorf("invalid content type %q", wire.Type) } diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go new file mode 100644 index 00000000..8cc7bdbd --- /dev/null +++ b/mcp/content_nil_test.go @@ -0,0 +1,226 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains tests to verify that UnmarshalJSON methods for Content types +// don't panic when unmarshaling onto nil pointers, as requested in GitHub issue #205. +// +// NOTE: The contentFromWire function has been fixed to handle nil wire.Content +// gracefully by returning an error instead of panicking. + +package mcp_test + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestContentUnmarshalNil(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + want interface{} + }{ + { + name: "CallToolResult nil Content", + json: `{"content":[{"type":"text","text":"hello"}]}`, + content: &mcp.CallToolResult{}, + want: &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + { + name: "CreateMessageResult nil Content", + json: `{"content":{"type":"text","text":"hello"},"model":"test","role":"user"}`, + content: &mcp.CreateMessageResult{}, + want: &mcp.CreateMessageResult{Content: &mcp.TextContent{Text: "hello"}, Model: "test", Role: "user"}, + }, + { + name: "PromptMessage nil Content", + json: `{"content":{"type":"text","text":"hello"},"role":"user"}`, + content: &mcp.PromptMessage{}, + want: &mcp.PromptMessage{Content: &mcp.TextContent{Text: "hello"}, Role: "user"}, + }, + { + name: "SamplingMessage nil Content", + json: `{"content":{"type":"text","text":"hello"},"role":"user"}`, + content: &mcp.SamplingMessage{}, + want: &mcp.SamplingMessage{Content: &mcp.TextContent{Text: "hello"}, Role: "user"}, + }, + { + name: "CallToolResultFor nil Content", + json: `{"content":[{"type":"text","text":"hello"}]}`, + content: &mcp.CallToolResult{}, + want: &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if err != nil { + t.Errorf("UnmarshalJSON failed: %v", err) + } + + // Verify that the Content field was properly populated + if cmp.Diff(tt.want, tt.content, ctrCmpOpts...) != "" { + t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content)) + } + }) + } +} + +func TestContentUnmarshalNilWithDifferentTypes(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "ImageContent", + json: `{"content":{"type":"image","mimeType":"image/png","data":"YTFiMmMz"}}`, + content: &mcp.CreateMessageResult{}, + expectError: false, + }, + { + name: "AudioContent", + json: `{"content":{"type":"audio","mimeType":"audio/wav","data":"YTFiMmMz"}}`, + content: &mcp.CreateMessageResult{}, + expectError: false, + }, + { + name: "ResourceLink", + json: `{"content":{"type":"resource_link","uri":"file:///test","name":"test"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // CreateMessageResult only allows text, image, audio + }, + { + name: "EmbeddedResource", + json: `{"content":{"type":"resource","resource":{"uri":"file://test","text":"test"}}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // CreateMessageResult only allows text, image, audio + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify that the Content field was properly populated for successful cases + if !tt.expectError { + if result, ok := tt.content.(*mcp.CreateMessageResult); ok { + if result.Content == nil { + t.Error("CreateMessageResult.Content was not populated") + } + } + } + }) + } +} + +func TestContentUnmarshalNilWithEmptyContent(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "Empty Content array", + json: `{"content":[]}`, + content: &mcp.CallToolResult{}, + expectError: false, + }, + { + name: "Missing Content field", + json: `{"model":"test","role":"user"}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // Content field is required for CreateMessageResult + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "Invalid content type", + json: `{"content":{"type":"invalid","text":"hello"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, + }, + { + name: "Missing type field", + json: `{"content":{"text":"hello"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} diff --git a/mcp/content_test.go b/mcp/content_test.go index 5ee6f66c..9366b0d4 100644 --- a/mcp/content_test.go +++ b/mcp/content_test.go @@ -22,6 +22,14 @@ func TestContent(t *testing.T) { &mcp.TextContent{Text: "hello"}, `{"type":"text","text":"hello"}`, }, + { + &mcp.TextContent{Text: ""}, + `{"type":"text","text":""}`, + }, + { + &mcp.TextContent{}, + `{"type":"text","text":""}`, + }, { &mcp.TextContent{ Text: "hello", @@ -37,6 +45,14 @@ func TestContent(t *testing.T) { }, `{"type":"image","mimeType":"image/png","data":"YTFiMmMz"}`, }, + { + &mcp.ImageContent{MIMEType: "image/png", Data: []byte{}}, + `{"type":"image","mimeType":"image/png","data":""}`, + }, + { + &mcp.ImageContent{Data: []byte("test")}, + `{"type":"image","mimeType":"","data":"dGVzdA=="}`, + }, { &mcp.ImageContent{ Data: []byte("a1b2c3"), @@ -53,6 +69,14 @@ func TestContent(t *testing.T) { }, `{"type":"audio","mimeType":"audio/wav","data":"YTFiMmMz"}`, }, + { + &mcp.AudioContent{MIMEType: "audio/wav", Data: []byte{}}, + `{"type":"audio","mimeType":"audio/wav","data":""}`, + }, + { + &mcp.AudioContent{Data: []byte("test")}, + `{"type":"audio","mimeType":"","data":"dGVzdA=="}`, + }, { &mcp.AudioContent{ Data: []byte("a1b2c3"), @@ -146,6 +170,10 @@ func TestEmbeddedResource(t *testing.T) { &mcp.ResourceContents{URI: "u", Blob: []byte{1}}, `{"uri":"u","blob":"AQ=="}`, }, + { + &mcp.ResourceContents{URI: "u", MIMEType: "m", Blob: []byte{1}, Meta: mcp.Meta{"key": "value"}}, + `{"uri":"u","mimeType":"m","blob":"AQ==","_meta":{"key":"value"}}`, + }, } { data, err := json.Marshal(tt.rc) if err != nil { diff --git a/mcp/create-repo.sh b/mcp/create-repo.sh deleted file mode 100755 index a68964bb..00000000 --- a/mcp/create-repo.sh +++ /dev/null @@ -1,80 +0,0 @@ -#!/bin/bash - -# This script creates an MCP SDK repo from x/tools/internal/mcp (and friends). -# It will be used as a one-off to create github.com/modelcontextprotocol/go-sdk. -# -# Requires https://github.com/newren/git-filter-repo. - -set -eu - -# Check if exactly one argument is provided -if [ "$#" -ne 1 ]; then - echo "create-repo.sh: create a standalone mcp SDK repo from x/tools" - echo "Usage: $0 " - exit 1 -fi >&2 - -src=$(go list -m -f {{.Dir}} golang.org/x/tools) -dest="$1" - -echo "Filtering MCP commits from ${src} to ${dest}..." >&2 - -startdir=$(pwd) -tempdir=$(mktemp -d) -function cleanup { - echo "cleaning up ${tempdir}" - rm -rf "$tempdir" -} >&2 -trap cleanup EXIT SIGINT - -echo "Checking out to ${tempdir}" - -git clone --bare "${src}" "${tempdir}" -git -C "${tempdir}" --git-dir=. filter-repo \ - --path internal/mcp/jsonschema --path-rename internal/mcp/jsonschema:jsonschema \ - --path internal/mcp/design --path-rename internal/mcp/design:design \ - --path internal/mcp/examples --path-rename internal/mcp/examples:examples \ - --path internal/mcp/internal --path-rename internal/mcp/internal:internal \ - --path internal/mcp/README.md --path-rename internal/mcp/README.md:README.md \ - --path internal/mcp/CONTRIBUTING.md --path-rename internal/mcp/CONTRIBUTING.md:CONTRIBUTING.md \ - --path internal/mcp --path-rename internal/mcp:mcp \ - --path internal/jsonrpc2_v2 --path-rename internal/jsonrpc2_v2:internal/jsonrpc2 \ - --path internal/xcontext \ - --replace-text "${startdir}/mcp-repo-replace.txt" \ - --force -mkdir ${dest} -cd "${dest}" -git init - -cat << EOF > LICENSE -MIT License - -Copyright (c) 2025 Go MCP SDK Authors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -EOF - -git add LICENSE && git commit -m "Initial commit: add LICENSE" -git remote add filtered_source "${tempdir}" -git pull filtered_source master --allow-unrelated-histories -git remote remove filtered_source -go mod init github.com/modelcontextprotocol/go-sdk && go get go@1.23.0 -go mod tidy -git add go.mod go.sum -git commit -m "all: add go.mod and go.sum file" diff --git a/mcp/event.go b/mcp/event.go new file mode 100644 index 00000000..bd78cdee --- /dev/null +++ b/mcp/event.go @@ -0,0 +1,427 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file is for SSE events. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events. + +package mcp + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "iter" + "maps" + "net/http" + "slices" + "strings" + "sync" +) + +// If true, MemoryEventStore will do frequent validation to check invariants, slowing it down. +// Enable for debugging. +const validateMemoryEventStore = false + +// An Event is a server-sent event. +// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. +type Event struct { + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field +} + +// Empty reports whether the Event is empty. +func (e Event) Empty() bool { + return e.Name == "" && e.ID == "" && len(e.Data) == 0 +} + +// writeEvent writes the event to w, and flushes. +func writeEvent(w io.Writer, evt Event) (int, error) { + var b bytes.Buffer + if evt.Name != "" { + fmt.Fprintf(&b, "event: %s\n", evt.Name) + } + if evt.ID != "" { + fmt.Fprintf(&b, "id: %s\n", evt.ID) + } + fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) + n, err := w.Write(b.Bytes()) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return n, err +} + +// scanEvents iterates SSE events in the given scanner. The iterated error is +// terminal: if encountered, the stream is corrupt or broken and should no +// longer be used. +// +// TODO(rfindley): consider a different API here that makes failure modes more +// apparent. +func scanEvents(r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewScanner(r) + const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size + scanner.Buffer(nil, maxTokenSize) + + // TODO: investigate proper behavior when events are out of order, or have + // non-standard names. + var ( + eventKey = []byte("event") + idKey = []byte("id") + dataKey = []byte("data") + ) + + return func(yield func(Event, error) bool) { + // iterate event from the wire. + // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples + // + // - `key: value` line records. + // - Consecutive `data: ...` fields are joined with newlines. + // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and + // 'data', these are the only three we consider. + // - Lines starting with ":" are ignored. + // - Records are terminated with two consecutive newlines. + var ( + evt Event + dataBuf *bytes.Buffer // if non-nil, preceding field was also data + ) + flushData := func() { + if dataBuf != nil { + evt.Data = dataBuf.Bytes() + dataBuf = nil + } + } + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + flushData() + // \n\n is the record delimiter + if !evt.Empty() && !yield(evt, nil) { + return + } + evt = Event{} + continue + } + before, after, found := bytes.Cut(line, []byte{':'}) + if !found { + yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) + return + } + if !bytes.Equal(before, dataKey) { + flushData() + } + switch { + case bytes.Equal(before, eventKey): + evt.Name = strings.TrimSpace(string(after)) + case bytes.Equal(before, idKey): + evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, dataKey): + data := bytes.TrimSpace(after) + if dataBuf != nil { + dataBuf.WriteByte('\n') + dataBuf.Write(data) + } else { + dataBuf = new(bytes.Buffer) + dataBuf.Write(data) + } + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) + } + if !yield(Event{}, err) { + return + } + } + flushData() + if !evt.Empty() { + yield(evt, nil) + } + } +} + +// An EventStore tracks data for SSE streams. +// A single EventStore suffices for all sessions, since session IDs are +// globally unique. So one EventStore can be created per process, for +// all Servers in the process. +// Such a store is able to bound resource usage for the entire process. +// +// All of an EventStore's methods must be safe for use by multiple goroutines. +type EventStore interface { + // Open prepares the event store for a given stream. It ensures that the + // underlying data structure for the stream is initialized, making it + // ready to store event streams. + // + // streamIDs must be globally unique. + Open(_ context.Context, sessionID, streamID string) error + + // Append appends data for an outgoing event to given stream, which is part of the + // given session. + Append(_ context.Context, sessionID, streamID string, data []byte) error + + // After returns an iterator over the data for the given session and stream, beginning + // just after the given index. + // Once the iterator yields a non-nil error, it will stop. + // After's iterator must return an error immediately if any data after index was + // dropped; it must not return partial results. + // The stream must have been opened previously (see [EventStore.Open]). + After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] + + // SessionClosed informs the store that the given session is finished, along + // with all of its streams. + // A store cannot rely on this method being called for cleanup. It should institute + // additional mechanisms, such as timeouts, to reclaim storage. + // + SessionClosed(_ context.Context, sessionID string) error + + // There is no StreamClosed method. A server doesn't know when a stream is finished, because + // the client can always send a GET with a Last-Event-ID referring to the stream. +} + +// A dataList is a list of []byte. +// The zero dataList is ready to use. +type dataList struct { + size int // total size of data bytes + first int // the stream index of the first element in data + data [][]byte +} + +func (dl *dataList) appendData(d []byte) { + // If we allowed empty data, we would consume memory without incrementing the size. + // We could of course account for that, but we keep it simple and assume there is no + // empty data. + if len(d) == 0 { + panic("empty data item") + } + dl.data = append(dl.data, d) + dl.size += len(d) +} + +// removeFirst removes the first data item in dl, returning the size of the item. +// It panics if dl is empty. +func (dl *dataList) removeFirst() int { + if len(dl.data) == 0 { + panic("empty dataList") + } + r := len(dl.data[0]) + dl.size -= r + dl.data[0] = nil // help GC + dl.data = dl.data[1:] + dl.first++ + return r +} + +// A MemoryEventStore is an [EventStore] backed by memory. +type MemoryEventStore struct { + mu sync.Mutex + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[string]*dataList // session ID -> stream ID -> *dataList +} + +// MemoryEventStoreOptions are options for a [MemoryEventStore]. +type MemoryEventStoreOptions struct{} + +// MaxBytes returns the maximum number of bytes that the store will retain before +// purging data. +func (s *MemoryEventStore) MaxBytes() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.maxBytes +} + +// SetMaxBytes sets the maximum number of bytes the store will retain before purging +// data. The argument must not be negative. If it is zero, a suitable default will be used. +// SetMaxBytes can be called at any time. The size of the store will be adjusted +// immediately. +func (s *MemoryEventStore) SetMaxBytes(n int) { + s.mu.Lock() + defer s.mu.Unlock() + switch { + case n < 0: + panic("negative argument") + case n == 0: + s.maxBytes = defaultMaxBytes + default: + s.maxBytes = n + } + s.purge() +} + +const defaultMaxBytes = 10 << 20 // 10 MiB + +// NewMemoryEventStore creates a [MemoryEventStore] with the default value +// for MaxBytes. +func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { + return &MemoryEventStore{ + maxBytes: defaultMaxBytes, + store: make(map[string]map[string]*dataList), + } +} + +// Open implements [EventStore.Open]. It ensures that the underlying data +// structures for the given session are initialized and ready for use. +func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.init(sessionID, streamID) + return nil +} + +// init is an internal helper function that ensures the nested map structure for a +// given sessionID and streamID exists, creating it if necessary. It returns the +// dataList associated with the specified IDs. +// Requires s.mu. +func (s *MemoryEventStore) init(sessionID, streamID string) *dataList { + streamMap, ok := s.store[sessionID] + if !ok { + streamMap = make(map[string]*dataList) + s.store[sessionID] = streamMap + } + dl, ok := streamMap[streamID] + if !ok { + dl = &dataList{} + streamMap[streamID] = dl + } + return dl +} + +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + dl := s.init(sessionID, streamID) + // Purge before adding, so at least the current data item will be present. + // (That could result in nBytes > maxBytes, but we'll live with that.) + s.purge() + dl.appendData(data) + s.nBytes += len(data) + return nil +} + +// ErrEventsPurged is the error that [EventStore.After] should return if the event just after the +// index is no longer available. +var ErrEventsPurged = errors.New("data purged") + +// After implements [EventStore.After]. +func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { + // Return the data items to yield. + // We must copy, because dataList.removeFirst nils out slice elements. + copyData := func() ([][]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + streamMap, ok := s.store[sessionID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown session ID %q", sessionID) + } + dl, ok := streamMap[streamID] + if !ok { + return nil, fmt.Errorf("MemoryEventStore.After: unknown stream ID %v in session %q", streamID, sessionID) + } + start := index + 1 + if dl.first > start { + return nil, fmt.Errorf("MemoryEventStore.After: index %d, stream ID %v, session %q: %w", + index, streamID, sessionID, ErrEventsPurged) + } + return slices.Clone(dl.data[start-dl.first:]), nil + } + + return func(yield func([]byte, error) bool) { + ds, err := copyData() + if err != nil { + yield(nil, err) + return + } + for _, d := range ds { + if !yield(d, nil) { + return + } + } + } +} + +// SessionClosed implements [EventStore.SessionClosed]. +func (s *MemoryEventStore) SessionClosed(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, dl := range s.store[sessionID] { + s.nBytes -= dl.size + } + delete(s.store, sessionID) + s.validate() + return nil +} + +// purge removes data until no more than s.maxBytes bytes are in use. +// It must be called with s.mu held. +func (s *MemoryEventStore) purge() { + // Remove the first element of every dataList until below the max. + for s.nBytes > s.maxBytes { + changed := false + for _, sm := range s.store { + for _, dl := range sm { + if dl.size > 0 { + r := dl.removeFirst() + if r > 0 { + changed = true + s.nBytes -= r + } + } + } + } + if !changed { + panic("no progress during purge") + } + } + s.validate() +} + +// validate checks that the store's data structures are valid. +// It must be called with s.mu held. +func (s *MemoryEventStore) validate() { + if !validateMemoryEventStore { + return + } + // Check that we're accounting for the size correctly. + n := 0 + for _, sm := range s.store { + for _, dl := range sm { + for _, d := range dl.data { + n += len(d) + } + } + } + if n != s.nBytes { + panic("sizes don't add up") + } +} + +// debugString returns a string containing the state of s. +// Used in tests. +func (s *MemoryEventStore) debugString() string { + s.mu.Lock() + defer s.mu.Unlock() + var b strings.Builder + for i, sess := range slices.Sorted(maps.Keys(s.store)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + sm := s.store[sess] + for i, sid := range slices.Sorted(maps.Keys(sm)) { + if i > 0 { + fmt.Fprintf(&b, "; ") + } + dl := sm[sid] + fmt.Fprintf(&b, "%s %s first=%d", sess, sid, dl.first) + for _, d := range dl.data { + fmt.Fprintf(&b, " %s", d) + } + } + } + return b.String() +} diff --git a/mcp/event_test.go b/mcp/event_test.go new file mode 100644 index 00000000..dacf30e8 --- /dev/null +++ b/mcp/event_test.go @@ -0,0 +1,298 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "slices" + "strings" + "testing" + "time" +) + +func TestScanEvents(t *testing.T) { + tests := []struct { + name string + input string + want []Event + wantErr string + }{ + { + name: "simple event", + input: "event: message\nid: 1\ndata: hello\n\n", + want: []Event{ + {Name: "message", ID: "1", Data: []byte("hello")}, + }, + }, + { + name: "multiple data lines", + input: "data: line 1\ndata: line 2\n\n", + want: []Event{ + {Data: []byte("line 1\nline 2")}, + }, + }, + { + name: "multiple events", + input: "data: first\n\nevent: second\ndata: second\n\n", + want: []Event{ + {Data: []byte("first")}, + {Name: "second", Data: []byte("second")}, + }, + }, + { + name: "no trailing newline", + input: "data: hello", + want: []Event{ + {Data: []byte("hello")}, + }, + }, + { + name: "malformed line", + input: "invalid line\n\n", + wantErr: "malformed line", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + var got []Event + var err error + for e, err2 := range scanEvents(r) { + if err2 != nil { + err = err2 + break + } + got = append(got, e) + } + + if tt.wantErr != "" { + if err == nil { + t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr) + } + return + } + + if err != nil { + t.Fatalf("scanEvents() returned unexpected error: %v", err) + } + + if len(got) != len(tt.want) { + t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want)) + } + + for i := range got { + if g, w := got[i].Name, tt.want[i].Name; g != w { + t.Errorf("event %d: name = %q, want %q", i, g, w) + } + if g, w := got[i].ID, tt.want[i].ID; g != w { + t.Errorf("event %d: id = %q, want %q", i, g, w) + } + if g, w := string(got[i].Data), string(tt.want[i].Data); g != w { + t.Errorf("event %d: data = %q, want %q", i, g, w) + } + } + }) + } +} + +func TestMemoryEventStoreState(t *testing.T) { + ctx := context.Background() + + appendEvent := func(s *MemoryEventStore, sess, stream string, data string) { + if err := s.Append(ctx, sess, stream, []byte(data)); err != nil { + t.Fatal(err) + } + } + + for _, tt := range []struct { + name string + actions func(*MemoryEventStore) + want string // output of debugString + wantSize int // value of nBytes + }{ + { + "appends", + func(s *MemoryEventStore) { + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") + }, + "S1 1 first=0 d1 d3; S1 2 first=0 d2; S2 8 first=0 d4", + 8, + }, + { + "session close", + func(s *MemoryEventStore) { + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") + s.SessionClosed(ctx, "S1") + }, + "S2 8 first=0 d4", + 2, + }, + { + "purge", + func(s *MemoryEventStore) { + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") + // We are using 8 bytes (d1,d2, d3, d4). + // To purge 6, we remove the first of each stream, leaving only d3. + s.SetMaxBytes(2) + }, + // The other streams remain, because we may add to them. + "S1 1 first=1 d3; S1 2 first=1; S2 8 first=1", + 2, + }, + { + "purge append", + func(s *MemoryEventStore) { + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") + s.SetMaxBytes(2) + // Up to here, identical to the "purge" case. + // Each of these additions will result in a purge. + appendEvent(s, "S1", "2", "d5") // remove d3 + appendEvent(s, "S1", "2", "d6") // remove d5 + }, + "S1 1 first=2; S1 2 first=2 d6; S2 8 first=1", + 2, + }, + { + "purge resize append", + func(s *MemoryEventStore) { + appendEvent(s, "S1", "1", "d1") + appendEvent(s, "S1", "2", "d2") + appendEvent(s, "S1", "1", "d3") + appendEvent(s, "S2", "8", "d4") + s.SetMaxBytes(2) + // Up to here, identical to the "purge" case. + s.SetMaxBytes(6) // make room + appendEvent(s, "S1", "2", "d5") + appendEvent(s, "S1", "2", "d6") + }, + // The other streams remain, because we may add to them. + "S1 1 first=1 d3; S1 2 first=1 d5 d6; S2 8 first=1", + 6, + }, + } { + t.Run(tt.name, func(t *testing.T) { + s := NewMemoryEventStore(nil) + tt.actions(s) + got := s.debugString() + if got != tt.want { + t.Errorf("\ngot %s\nwant %s", got, tt.want) + } + if g, w := s.nBytes, tt.wantSize; g != w { + t.Errorf("got size %d, want %d", g, w) + } + }) + } +} + +func TestMemoryEventStoreAfter(t *testing.T) { + ctx := context.Background() + s := NewMemoryEventStore(nil) + s.SetMaxBytes(4) + s.Append(ctx, "S1", "1", []byte("d1")) + s.Append(ctx, "S1", "1", []byte("d2")) + s.Append(ctx, "S1", "1", []byte("d3")) + s.Append(ctx, "S1", "2", []byte("d4")) // will purge d1 + want := "S1 1 first=1 d2 d3; S1 2 first=0 d4" + if got := s.debugString(); got != want { + t.Fatalf("got state %q, want %q", got, want) + } + + for _, tt := range []struct { + sessionID string + streamID string + index int + want []string + wantErr string // if non-empty, error should contain this string + }{ + {"S1", "1", 0, []string{"d2", "d3"}, ""}, + {"S1", "1", 1, []string{"d3"}, ""}, + {"S1", "1", 2, nil, ""}, + {"S1", "2", 0, nil, ""}, + {"S1", "3", 0, nil, "unknown stream ID"}, + {"S2", "0", 0, nil, "unknown session ID"}, + } { + t.Run(fmt.Sprintf("%s-%s-%d", tt.sessionID, tt.streamID, tt.index), func(t *testing.T) { + var got []string + for d, err := range s.After(ctx, tt.sessionID, tt.streamID, tt.index) { + if err != nil { + if tt.wantErr == "" { + t.Fatalf("unexpected error %q", err) + } else if g := err.Error(); !strings.Contains(g, tt.wantErr) { + t.Fatalf("got error %q, want it to contain %q", g, tt.wantErr) + } else { + return + } + } + got = append(got, string(d)) + } + if tt.wantErr != "" { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !slices.Equal(got, tt.want) { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func BenchmarkMemoryEventStore(b *testing.B) { + // Benchmark with various settings for event store size, number of session, + // and payload size. + // + // Assume a small number of streams per session, which is probably realistic. + tests := []struct { + name string + limit int + sessions int + datasize int + }{ + {"1KB", 1024, 1, 16}, + {"1MB", 1024 * 1024, 10, 16}, + {"10MB", 10 * 1024 * 1024, 100, 16}, + {"10MB_big", 10 * 1024 * 1024, 1000, 128}, + } + + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + store := NewMemoryEventStore(nil) + store.SetMaxBytes(test.limit) + ctx := context.Background() + sessionIDs := make([]string, test.sessions) + streamIDs := make([][3]string, test.sessions) + for i := range sessionIDs { + sessionIDs[i] = fmt.Sprint(i) + for j := range 3 { + streamIDs[i][j] = randText() + } + } + payload := make([]byte, test.datasize) + start := time.Now() + b.ResetTimer() + for i := range b.N { + sessionID := sessionIDs[i%len(sessionIDs)] + streamID := streamIDs[i%len(sessionIDs)][i%3] + store.Append(ctx, sessionID, streamID, payload) + } + b.ReportMetric(float64(test.datasize)*float64(b.N)/time.Since(start).Seconds(), "bytes/s") + }) + } +} diff --git a/mcp/features.go b/mcp/features.go index 1777b33f..438370fe 100644 --- a/mcp/features.go +++ b/mcp/features.go @@ -17,7 +17,9 @@ import ( // A featureSet is a collection of features of type T. // Every feature has a unique ID, and the spec never mentions // an ordering for the List calls, so what it calls a "list" is actually a set. -// TODO: switch to an ordered map +// +// An alternative implementation would use an ordered map, but that's probably +// not necessary as adds and removes are rare, and usually batched. type featureSet[T any] struct { uniqueID func(T) string features map[string]T @@ -66,6 +68,9 @@ func (s *featureSet[T]) get(uid string) (T, bool) { return t, ok } +// len returns the number of features in the set. +func (s *featureSet[T]) len() int { return len(s.features) } + // all returns an iterator over of all the features in the set // sorted by unique ID. func (s *featureSet[T]) all() iter.Seq[T] { diff --git a/mcp/features_test.go b/mcp/features_test.go index 5ffbce8c..6df9b16e 100644 --- a/mcp/features_test.go +++ b/mcp/features_test.go @@ -5,31 +5,22 @@ package mcp import ( - "context" "slices" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[SayHiParams]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{ - Content: []Content{ - &TextContent{Text: "Hi " + params.Name}, - }, - }, nil -} - func TestFeatureSetOrder(t *testing.T) { - toolA := NewServerTool("apple", "apple tool", SayHi).Tool - toolB := NewServerTool("banana", "banana tool", SayHi).Tool - toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool + toolA := &Tool{Name: "apple", Description: "apple tool"} + toolB := &Tool{Name: "banana", Description: "banana tool"} + toolC := &Tool{Name: "cherry", Description: "cherry tool"} testCases := []struct { tools []*Tool @@ -52,9 +43,9 @@ func TestFeatureSetOrder(t *testing.T) { } func TestFeatureSetAbove(t *testing.T) { - toolA := NewServerTool("apple", "apple tool", SayHi).Tool - toolB := NewServerTool("banana", "banana tool", SayHi).Tool - toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool + toolA := &Tool{Name: "apple", Description: "apple tool"} + toolB := &Tool{Name: "banana", Description: "banana tool"} + toolC := &Tool{Name: "cherry", Description: "cherry tool"} testCases := []struct { tools []*Tool diff --git a/mcp/logging.go b/mcp/logging.go index 4880e179..b3186a96 100644 --- a/mcp/logging.go +++ b/mcp/logging.go @@ -70,6 +70,7 @@ type LoggingHandlerOptions struct { // The value for the "logger" field of logging notifications. LoggerName string // Limits the rate at which log messages are sent. + // Excess messages are dropped. // If zero, there is no rate limiting. MinInterval time.Duration } @@ -117,7 +118,7 @@ func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool { // This is also checked in ServerSession.LoggingMessage, so checking it here // is just an optimization that skips building the JSON. h.ss.mu.Lock() - mcpLevel := h.ss.logLevel + mcpLevel := h.ss.state.LogLevel h.ss.mu.Unlock() return level >= mcpLevelToSlog(mcpLevel) } diff --git a/mcp/mcp-repo-replace.txt b/mcp/mcp-repo-replace.txt deleted file mode 100644 index 3409dd7a..00000000 --- a/mcp/mcp-repo-replace.txt +++ /dev/null @@ -1,9 +0,0 @@ -"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"==>"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" -github.com/modelcontextprotocol/go-sdk/internal/xcontext==>github.com/modelcontextprotocol/go-sdk/internal/xcontext -github.com/modelcontextprotocol/go-sdk/internal==>github.com/modelcontextprotocol/go-sdk/internal -github.com/modelcontextprotocol/go-sdk/jsonschema==>github.com/modelcontextprotocol/go-sdk/jsonschema -github.com/modelcontextprotocol/go-sdk/examples==>github.com/modelcontextprotocol/go-sdk/examples -github.com/modelcontextprotocol/go-sdk/design==>github.com/modelcontextprotocol/go-sdk/design -github.com/modelcontextprotocol/go-sdk/mcp==>github.com/modelcontextprotocol/go-sdk/mcp -governed by an MIT-style==>governed by an MIT-style -regex:Copyright (20\d\d) The Go Authors==>Copyright \1 The Go MCP SDK Authors diff --git a/mcp/mcp.go b/mcp/mcp.go index a22748c9..88321a1e 100644 --- a/mcp/mcp.go +++ b/mcp/mcp.go @@ -2,37 +2,87 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -//go:generate ../internal/readme/build.sh - // The mcp package provides an SDK for writing model context protocol clients // and servers. // -// To get started, create either a [Client] or [Server], and connect it to a -// peer using a [Transport]. The diagram below illustrates how this works: +// To get started, create either a [Client] or [Server], add features to it +// using `AddXXX` functions, and connect it to a peer using a [Transport]. +// +// For example, to run a simple server on the [StdioTransport]: +// +// server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil) +// +// // Using the generic AddTool automatically populates the the input and output +// // schema of the tool. +// type args struct { +// Name string `json:"name" jsonschema:"the person to greet"` +// } +// mcp.AddTool(server, &mcp.Tool{ +// Name: "greet", +// Description: "say hi", +// }, func(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { +// return &mcp.CallToolResult{ +// Content: []mcp.Content{ +// &mcp.TextContent{Text: "Hi " + args.Name}, +// }, +// }, nil, nil +// }) +// +// // Run the server on the stdio transport. +// if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { +// log.Printf("Server failed: %v", err) +// } +// +// To connect to this server, use the [CommandTransport]: +// +// client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) +// transport := &mcp.CommandTransport{Command: exec.Command("myserver")} +// session, err := client.Connect(ctx, transport, nil) +// if err != nil { +// log.Fatal(err) +// } +// defer session.Close() +// +// params := &mcp.CallToolParams{ +// Name: "greet", +// Arguments: map[string]any{"name": "you"}, +// } +// res, err := session.CallTool(ctx, params) +// if err != nil { +// log.Fatalf("CallTool failed: %v", err) +// } +// +// # Clients, servers, and sessions +// +// In this SDK, both a [Client] and [Server] may handle many concurrent +// connections. Each time a client or server is connected to a peer using a +// [Transport], it creates a new session (either a [ClientSession] or +// [ServerSession]): // // Client Server // ⇅ (jsonrpc2) ⇅ // ClientSession ⇄ Client Transport ⇄ Server Transport ⇄ ServerSession // -// A [Client] is an MCP client, which can be configured with various client -// capabilities. Clients may be connected to a [Server] instance -// using the [Client.Connect] method. -// -// Similarly, a [Server] is an MCP server, which can be configured with various -// server capabilities. Servers may be connected to one or more [Client] -// instances using the [Server.Connect] method, which creates a -// [ServerSession]. -// -// A [Transport] connects a bidirectional [Connection] of jsonrpc2 messages. In -// practice, transports in the MCP spec are are either client transports or -// server transports. For example, the [StdioTransport] is a server transport -// that communicates over stdin/stdout, and its counterpart is a -// [CommandTransport] that communicates with a subprocess over its -// stdin/stdout. -// -// Some transports may hide more complicated details, such as an -// [SSEClientTransport], which reads messages via server-sent events on a -// hanging GET request, and writes them to a POST endpoint. Users of this SDK -// may define their own custom Transports by implementing the [Transport] -// interface. +// The session types expose an API to interact with its peer. For example, +// [ClientSession.CallTool] or [ServerSession.ListRoots]. +// +// # Adding features +// +// Add MCP servers to your Client or Server using AddXXX methods (for example +// [Client.AddRoot] or [Server.AddPrompt]). If any peers are connected when +// AddXXX is called, they will receive a corresponding change notification +// (for example notifications/roots/list_changed). +// +// Adding tools is special: tools may be bound to ordinary Go functions by +// using the top-level generic [AddTool] function, which allows specifying an +// input and output type. When AddTool is used, the tool's input schema and +// output schema are automatically populated, and inputs are automatically +// validated. As a special case, if the output type is 'any', no output schema +// is generated. +// +// func double(_ context.Context, _ *mcp.CallToolRequest, in In) (*mcp.CallToolResponse, Out, error) { +// return nil, Out{Answer: 2*in.Number}, nil +// } +// ... +// mcp.AddTool(&mcp.Tool{Name: "double", Description: "double a number"}, double) package mcp diff --git a/mcp/mcp_example_test.go b/mcp/mcp_example_test.go new file mode 100644 index 00000000..25f39fb8 --- /dev/null +++ b/mcp/mcp_example_test.go @@ -0,0 +1,166 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+lifecycle + +func Example_lifecycle() { + ctx := context.Background() + + // Create a client and server. + // Wait for the client to initialize the session. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, &mcp.ServerOptions{ + InitializedHandler: func(context.Context, *mcp.InitializedRequest) { + fmt.Println("initialized!") + }, + }) + + // Connect the server and client using in-memory transports. + // + // Connect the server first so that it's ready to receive initialization + // messages from the client. + t1, t2 := mcp.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + clientSession, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + + // Now shut down the session by closing the client, and waiting for the + // server session to end. + if err := clientSession.Close(); err != nil { + log.Fatal(err) + } + if err := serverSession.Wait(); err != nil { + log.Fatal(err) + } + // Output: initialized! +} + +// !-lifecycle + +// !+progress + +func Example_progress() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "makeProgress"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + if token := req.Params.GetProgressToken(); token != nil { + for i := range 3 { + params := &mcp.ProgressNotificationParams{ + Message: "frobbing widgets", + ProgressToken: token, + Progress: float64(i), + Total: 2, + } + req.Session.NotifyProgress(ctx, params) // ignore error + } + } + return &mcp.CallToolResult{}, nil, nil + }) + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, &mcp.ClientOptions{ + ProgressNotificationHandler: func(_ context.Context, req *mcp.ProgressNotificationClientRequest) { + fmt.Printf("%s %.0f/%.0f\n", req.Params.Message, req.Params.Progress, req.Params.Total) + }, + }) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "makeProgress", + Meta: mcp.Meta{"progressToken": "abc123"}, + }); err != nil { + log.Fatal(err) + } + // Output: + // frobbing widgets 0/2 + // frobbing widgets 1/2 + // frobbing widgets 2/2 +} + +// !-progress + +// !+cancellation + +func Example_cancellation() { + // For this example, we're going to be collecting observations from the + // server and client. + var clientResult, serverResult string + var wg sync.WaitGroup + wg.Add(2) + + // Create a server with a single slow tool. + // When the client cancels its request, the server should observe + // cancellation. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + started := make(chan struct{}, 1) // signals that the server started handling the tool call + mcp.AddTool(server, &mcp.Tool{Name: "slow"}, func(ctx context.Context, req *mcp.CallToolRequest, _ any) (*mcp.CallToolResult, any, error) { + started <- struct{}{} + defer wg.Done() + select { + case <-time.After(5 * time.Second): + serverResult = "tool done" + case <-ctx.Done(): + serverResult = "tool canceled" + } + return &mcp.CallToolResult{}, nil, nil + }) + + // Connect a client to the server. + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + session, err := client.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // Make a tool call, asynchronously. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + _, err = session.CallTool(ctx, &mcp.CallToolParams{Name: "slow"}) + clientResult = fmt.Sprintf("%v", err) + }() + + // As soon as the server has started handling the call, cancel it from the + // client side. + <-started + cancel() + wg.Wait() + + fmt.Println(clientResult) + fmt.Println(serverResult) + // Output: + // context canceled + // tool canceled +} + +// !-cancellation diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5f42b1b9..9cc105cd 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -17,24 +18,42 @@ import ( "slices" "strings" "sync" + "sync/atomic" "testing" "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" - "github.com/modelcontextprotocol/go-sdk/jsonschema" ) type hiParams struct { Name string } -func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { - if err := ss.Ping(ctx, nil); err != nil { - return nil, fmt.Errorf("ping failed: %v", err) +// TODO(jba): after schemas are stateless (WIP), this can be a variable. +func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } + +func sayHi(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { + if err := req.Session.Ping(ctx, nil); err != nil { + return nil, nil, fmt.Errorf("ping failed: %v", err) } - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + params.Arguments.Name}}}, nil + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil +} + +var codeReviewPrompt = &Prompt{ + Name: "code_review", + Description: "do a code review", + Arguments: []*PromptArgument{{Name: "Code", Required: true}}, +} + +func codReviewPromptHandler(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{ + {Role: "user", Content: &TextContent{Text: "Please review the following code: " + req.Params.Arguments["Code"]}}, + }, + }, nil } func TestEndToEnd(t *testing.T) { @@ -43,7 +62,7 @@ func TestEndToEnd(t *testing.T) { // Channels to check if notification callbacks happened. notificationChans := map[string]chan int{} - for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} { + for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} { notificationChans[name] = make(chan int, 1) } waitForNotification := func(t *testing.T, name string) { @@ -56,19 +75,42 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *ServerSession, *InitializedParams) { notificationChans["initialized"] <- 0 }, - RootsListChangedHandler: func(context.Context, *ServerSession, *RootsListChangedParams) { notificationChans["roots"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { + InitializedHandler: func(context.Context, *InitializedRequest) { + notificationChans["initialized"] <- 0 + }, + RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) { + notificationChans["roots"] <- 0 + }, + ProgressNotificationHandler: func(context.Context, *ProgressNotificationServerRequest) { notificationChans["progress_server"] <- 0 }, + SubscribeHandler: func(context.Context, *SubscribeRequest) error { + notificationChans["subscribe"] <- 0 + return nil + }, + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { + notificationChans["unsubscribe"] <- 0 + return nil + }, } - s := NewServer("testServer", "v1.0.0", sopts) - add(tools, s.AddTools, "greet", "fail") - add(prompts, s.AddPrompts, "code_review", "fail") - add(resources, s.AddResources, "info.txt", "fail.txt") + s := NewServer(testImpl, sopts) + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return nil, nil, errTestFailure + }) + s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) + s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *GetPromptRequest) (*GetPromptResult, error) { + return nil, errTestFailure + }) + s.AddResource(resource1, readHandler) + s.AddResource(resource2, readHandler) // Connect the server. - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -88,20 +130,38 @@ func TestEndToEnd(t *testing.T) { loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { - return &CreateMessageResult{Model: "aModel"}, nil + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + ElicitationHandler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{ + Action: "accept", + Content: map[string]any{ + "name": "Test User", + "email": "test@example.com", + }, + }, nil }, - ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, - PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, - ResourceListChangedHandler: func(context.Context, *ClientSession, *ResourceListChangedParams) { notificationChans["resources"] <- 0 }, - LoggingMessageHandler: func(_ context.Context, _ *ClientSession, lm *LoggingMessageParams) { - loggingMessages <- lm + ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { + notificationChans["tools"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { + PromptListChangedHandler: func(context.Context, *PromptListChangedRequest) { + notificationChans["prompts"] <- 0 + }, + ResourceListChangedHandler: func(context.Context, *ResourceListChangedRequest) { + notificationChans["resources"] <- 0 + }, + LoggingMessageHandler: func(_ context.Context, req *LoggingMessageRequest) { + loggingMessages <- req.Params + }, + ProgressNotificationHandler: func(context.Context, *ProgressNotificationClientRequest) { notificationChans["progress_client"] <- 0 }, + ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) { + notificationChans["resource_updated"] <- 0 + }, } - c := NewClient("testClient", "v1.0.0", opts) + c := NewClient(testImpl, opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) if err != nil { t.Fatal(err) @@ -109,7 +169,7 @@ func TestEndToEnd(t *testing.T) { c.AddRoots(&Root{URI: "file://" + rootAbs}) // Connect the client. - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -154,42 +214,17 @@ func TestEndToEnd(t *testing.T) { t.Errorf("fail returned unexpected error: got %v, want containing %v", err, errTestFailure) } - s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) + s.AddPrompt(&Prompt{Name: "T"}, nil) waitForNotification(t, "prompts") s.RemovePrompts("T") waitForNotification(t, "prompts") }) t.Run("tools", func(t *testing.T) { - res, err := cs.ListTools(ctx, nil) - if err != nil { - t.Errorf("tools/list failed: %v", err) - } - wantTools := []*Tool{ - { - Name: "fail", - InputSchema: nil, - }, - { - Name: "greet", - Description: "say hi", - InputSchema: &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string"}, - }, - AdditionalProperties: falseSchema(), - }, - }, - } - if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) - } - + // ListTools is tested in client_list_test.go. gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "user"}, + Arguments: map[string]any{"Name": "user"}, }) if err != nil { t.Fatal(err) @@ -199,7 +234,7 @@ func TestEndToEnd(t *testing.T) { &TextContent{Text: "hi user"}, }, } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { + if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } @@ -218,11 +253,31 @@ func TestEndToEnd(t *testing.T) { &TextContent{Text: errTestFailure.Error()}, }, } - if diff := cmp.Diff(wantFail, gotFail); diff != "" { + if diff := cmp.Diff(wantFail, gotFail, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } - s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}, Handler: nopHandler}) + // Check output schema validation. + badout := &Tool{ + Name: "badout", + OutputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "x": {Type: "string"}, + }, + }, + } + AddTool(s, badout, func(_ context.Context, _ *CallToolRequest, arg map[string]any) (*CallToolResult, map[string]any, error) { + return nil, map[string]any{"x": 1}, nil + }) + _, err = cs.CallTool(ctx, &CallToolParams{Name: "badout"}) + wantMsg := `has type "integer", want "string"` + if err == nil || !strings.Contains(err.Error(), wantMsg) { + t.Errorf("\ngot %q\nwant error message containing %q", err, wantMsg) + } + + // Check tools-changed notifications. + s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{Type: "object"}}, nopHandler) waitForNotification(t, "tools") s.RemoveTools("T") waitForNotification(t, "tools") @@ -246,8 +301,7 @@ func TestEndToEnd(t *testing.T) { MIMEType: "text/template", URITemplate: "file:///{+filename}", // the '+' means that filename can contain '/' } - st := &ServerResourceTemplate{ResourceTemplate: template, Handler: readHandler} - s.AddResourceTemplates(st) + s.AddResourceTemplate(template, readHandler) tres, err := cs.ListResourceTemplates(ctx, nil) if err != nil { t.Fatal(err) @@ -292,7 +346,7 @@ func TestEndToEnd(t *testing.T) { } } - s.AddResources(&ServerResource{Resource: &Resource{URI: "http://U"}}) + s.AddResource(&Resource{URI: "http://U"}, nil) waitForNotification(t, "resources") s.RemoveResources("http://U") waitForNotification(t, "resources") @@ -381,7 +435,7 @@ func TestEndToEnd(t *testing.T) { // Nothing should be logged until the client sets a level. mustLog("info", "before") - if err := cs.SetLevel(ctx, &SetLevelParams{Level: "warning"}); err != nil { + if err := cs.SetLoggingLevel(ctx, &SetLoggingLevelParams{Level: "warning"}); err != nil { t.Fatal(err) } mustLog("warning", want[0].Data) @@ -419,6 +473,49 @@ func TestEndToEnd(t *testing.T) { waitForNotification(t, "progress_server") }) + t.Run("resource_subscriptions", func(t *testing.T) { + err := cs.Subscribe(ctx, &SubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "subscribe") + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + waitForNotification(t, "resource_updated") + err = cs.Unsubscribe(ctx, &UnsubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "unsubscribe") + + // Verify the client does not receive the update after unsubscribing. + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + select { + case <-notificationChans["resource_updated"]: + t.Fatalf("resource updated after unsubscription") + case <-time.After(time.Second): + } + }) + + t.Run("elicitation", func(t *testing.T) { + result, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Please provide information", + }) + if err != nil { + t.Fatal(err) + } + if result.Action != "accept" { + t.Errorf("got action %q, want %q", result.Action, "accept") + } + }) + // Disconnect. cs.Close() clientWG.Wait() @@ -434,40 +531,6 @@ func TestEndToEnd(t *testing.T) { var ( errTestFailure = errors.New("mcp failure") - tools = map[string]*ServerTool{ - "greet": NewServerTool("greet", "say hi", sayHi), - "fail": { - Tool: &Tool{Name: "fail"}, - Handler: func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { - return nil, errTestFailure - }, - }, - } - - prompts = map[string]*ServerPrompt{ - "code_review": { - Prompt: &Prompt{ - Name: "code_review", - Description: "do a code review", - Arguments: []*PromptArgument{{Name: "Code", Required: true}}, - }, - Handler: func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { - return &GetPromptResult{ - Description: "Code review prompt", - Messages: []*PromptMessage{ - {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, - }, - }, nil - }, - }, - "fail": { - Prompt: &Prompt{Name: "fail"}, - Handler: func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { - return nil, errTestFailure - }, - }, - } - resource1 = &Resource{ Name: "public", MIMEType: "text/plain", @@ -484,19 +547,14 @@ var ( URI: "embedded:info", } readHandler = fileResourceHandler("testdata/files") - resources = map[string]*ServerResource{ - "info.txt": {resource1, readHandler}, - "fail.txt": {resource2, readHandler}, - "info": {resource3, handleEmbeddedResource}, - } ) var embeddedResources = map[string]string{ "info": "This is the MCP test server.", } -func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { - u, err := url.Parse(params.URI) +func handleEmbeddedResource(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + u, err := url.Parse(req.Params.URI) if err != nil { return nil, err } @@ -510,22 +568,11 @@ func handleEmbeddedResource(_ context.Context, _ *ServerSession, params *ReadRes } return &ReadResourceResult{ Contents: []*ResourceContents{ - {URI: params.URI, MIMEType: "text/plain", Text: text}, + {URI: req.Params.URI, MIMEType: "text/plain", Text: text}, }, }, nil } -// Add calls the given function to add the named features. -func add[T any](m map[string]T, add func(...T), names ...string) { - for _, name := range names { - feat, ok := m[name] - if !ok { - panic("missing feature " + name) - } - add(feat) - } -} - // errorCode returns the code associated with err. // If err is nil, it returns 0. // If there is no code, it returns -1. @@ -540,37 +587,64 @@ func errorCode(err error) int64 { return -1 } -// basicConnection returns a new basic client-server connection configured with -// the provided tools. +// basicConnection returns a new basic client-server connection, with the server +// configured via the provided function. +// +// The caller should cancel either the client connection or server connection +// when the connections are no longer needed. +// +// The returned func cleans up by closing the client and waiting for the server +// to shut down. +func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession, func()) { + return basicClientServerConnection(t, nil, nil, config) +} + +// basicClientServerConnection creates a basic connection between client and +// server. If either client or server is nil, empty implementations are used. +// +// The provided function may be used to configure features on the resulting +// server, prior to connection. // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) { +// +// The returned func cleans up by closing the client and waiting for the server +// to shut down. +func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession, func()) { t.Helper() ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) - - // The 'greet' tool says hi. - s.AddTools(tools...) - ss, err := s.Connect(ctx, st) + if server == nil { + server = NewServer(testImpl, nil) + } + if config != nil { + config(server) + } + ss, err := server.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) - cs, err := c.Connect(ctx, ct) + if client == nil { + client = NewClient(testImpl, nil) + } + cs, err := client.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } - return ss, cs + return cs, ss, func() { + cs.Close() + ss.Wait() + } } func TestServerClosing(t *testing.T) { - cc, cs := basicConnection(t, NewServerTool("greet", "say hi", sayHi)) - defer cs.Close() + cs, ss, cleanup := basicConnection(t, func(s *Server) { + AddTool(s, greetTool(), sayHi) + }) + defer cleanup() ctx := context.Background() var wg sync.WaitGroup @@ -583,11 +657,11 @@ func TestServerClosing(t *testing.T) { }() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "user"}, + Arguments: map[string]any{"Name": "user"}, }); err != nil { t.Fatalf("after connecting: %v", err) } - cc.Close() + ss.Close() wg.Wait() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", @@ -601,18 +675,18 @@ func TestBatching(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) - _, err := s.Connect(ctx, st) + s := NewServer(testImpl, nil) + _, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) // TODO: this test is broken, because increasing the batch size here causes // 'initialize' to block. Therefore, we can only test with a size of 1. // Since batching is being removed, we can probably just delete this. const batchSize = 1 - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -640,23 +714,20 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - - slowRequest := func(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + slowTool := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): cancelled <- struct{}{} case <-time.After(5 * time.Second): - return nil, nil + return nil, nil, nil } - return nil, nil + return nil, nil, nil } - st := &ServerTool{ - Tool: &Tool{Name: "slow"}, - Handler: slowRequest, - } - _, cs := basicConnection(t, st) - defer cs.Close() + cs, _, cleanup := basicConnection(t, func(s *Server) { + AddTool(s, &Tool{Name: "slow", InputSchema: &jsonschema.Schema{Type: "object"}}, slowTool) + }) + defer cleanup() ctx, cancel := context.WithCancel(context.Background()) go cs.CallTool(ctx, &CallToolParams{Name: "slow"}) @@ -673,19 +744,16 @@ func TestMiddleware(t *testing.T) { ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer("testServer", "v1.0.0", nil) - ss, err := s.Connect(ctx, st) + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } // Wait for the server to exit after the client closes its connection. - var clientWG sync.WaitGroup - clientWG.Add(1) - go func() { + defer func() { if err := ss.Wait(); err != nil { t.Errorf("server failed: %v", err) } - clientWG.Done() }() var sbuf, cbuf bytes.Buffer @@ -697,14 +765,16 @@ func TestMiddleware(t *testing.T) { s.AddSendingMiddleware(traceCalls[*ServerSession](&sbuf, "S1"), traceCalls[*ServerSession](&sbuf, "S2")) s.AddReceivingMiddleware(traceCalls[*ServerSession](&sbuf, "R1"), traceCalls[*ServerSession](&sbuf, "R2")) - c := NewClient("testClient", "v1.0.0", nil) + c := NewClient(testImpl, nil) c.AddSendingMiddleware(traceCalls[*ClientSession](&cbuf, "S1"), traceCalls[*ClientSession](&cbuf, "S2")) c.AddReceivingMiddleware(traceCalls[*ClientSession](&cbuf, "R1"), traceCalls[*ClientSession](&cbuf, "R2")) - cs, err := c.Connect(ctx, ct) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } + defer cs.Close() + if _, err := cs.ListTools(ctx, nil); err != nil { t.Fatal(err) } @@ -780,16 +850,16 @@ func TestNoJSONNull(t *testing.T) { // Collect logs, to sanity check that we don't write JSON null anywhere. var logbuf safeBuffer - ct = NewLoggingTransport(ct, &logbuf) + ct = &LoggingTransport{Transport: ct, Writer: &logbuf} - s := NewServer("testServer", "v1.0.0", nil) - ss, err := s.Connect(ctx, st) + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - c := NewClient("testClient", "v1.0.0", nil) - cs, err := c.Connect(ctx, ct) + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -822,20 +892,17 @@ func TestNoJSONNull(t *testing.T) { // traceCalls creates a middleware function that prints the method before and after each call // with the given prefix. -func traceCalls[S Session](w io.Writer, prefix string) Middleware[S] { - return func(h MethodHandler[S]) MethodHandler[S] { - return func(ctx context.Context, sess S, method string, params Params) (Result, error) { +func traceCalls[S Session](w io.Writer, prefix string) Middleware { + return func(h MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { fmt.Fprintf(w, "%s >%s\n", prefix, method) defer fmt.Fprintf(w, "%s <%s\n", prefix, method) - return h(ctx, sess, method, params) + return h(ctx, method, req) } } } -// A function, because schemas must form a tree (they have hidden state). -func falseSchema() *jsonschema.Schema { return &jsonschema.Schema{Not: &jsonschema.Schema{}} } - -func nopHandler(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { +func nopHandler(context.Context, *CallToolRequest) (*CallToolResult, error) { return nil, nil } @@ -851,10 +918,10 @@ func TestKeepAlive(t *testing.T) { serverOpts := &ServerOptions{ KeepAlive: 100 * time.Millisecond, } - s := NewServer("testServer", "v1.0.0", serverOpts) - s.AddTools(NewServerTool("greet", "say hi", sayHi)) + s := NewServer(testImpl, serverOpts) + AddTool(s, greetTool(), sayHi) - ss, err := s.Connect(ctx, st) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -863,8 +930,8 @@ func TestKeepAlive(t *testing.T) { clientOpts := &ClientOptions{ KeepAlive: 100 * time.Millisecond, } - c := NewClient("testClient", "v1.0.0", clientOpts) - cs, err := c.Connect(ctx, ct) + c := NewClient(testImpl, clientOpts) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -889,6 +956,518 @@ func TestKeepAlive(t *testing.T) { } } +func TestElicitationUnsupportedMethod(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Server + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + // Client without ElicitationHandler + c := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test that elicitation fails when no handler is provided + _, err = ss.Elicit(ctx, &ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + + if err == nil { + t.Error("expected error when ElicitationHandler is not provided, got nil") + } + if code := errorCode(err); code != CodeUnsupportedMethod { + t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, CodeUnsupportedMethod) + } + if !strings.Contains(err.Error(), "does not support elicitation") { + t.Errorf("error should mention unsupported elicitation, got: %v", err) + } +} + +func TestElicitationSchemaValidation(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test valid schemas - these should not return errors + validSchemas := []struct { + name string + schema *jsonschema.Schema + }{ + { + name: "nil schema", + schema: nil, + }, + { + name: "empty object schema", + schema: &jsonschema.Schema{ + Type: "object", + }, + }, + { + name: "simple string property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + }, + }, + { + name: "string with valid formats", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "email": {Type: "string", Format: "email"}, + "website": {Type: "string", Format: "uri"}, + "birthday": {Type: "string", Format: "date"}, + "created": {Type: "string", Format: "date-time"}, + }, + }, + }, + { + name: "string with constraints", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(1), MaxLength: ptr(100)}, + }, + }, + }, + { + name: "number with constraints", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "age": {Type: "integer", Minimum: ptr(0.0), Maximum: ptr(150.0)}, + "score": {Type: "number", Minimum: ptr(0.0), Maximum: ptr(100.0)}, + }, + }, + }, + { + name: "boolean with default", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "enabled": {Type: "boolean", Default: json.RawMessage("true")}, + }, + }, + }, + { + name: "string enum", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "status": { + Type: "string", + Enum: []any{ + "active", + "inactive", + "pending", + }, + }, + }, + }, + }, + { + name: "enum with matching enumNames", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + Enum: []any{ + "high", + "medium", + "low", + }, + Extra: map[string]any{ + "enumNames": []any{"High Priority", "Medium Priority", "Low Priority"}, + }, + }, + }, + }, + }, + } + + for _, tc := range validSchemas { + t.Run("valid_"+tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test valid schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if err != nil { + t.Errorf("expected no error for valid schema %q, got: %v", tc.name, err) + } + }) + } + + // Test invalid schemas - these should return errors + invalidSchemas := []struct { + name string + schema *jsonschema.Schema + expectedError string + }{ + { + name: "root schema non-object type", + schema: &jsonschema.Schema{ + Type: "string", + }, + expectedError: "elicit schema must be of type 'object', got \"string\"", + }, + { + name: "nested object property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "user": { + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + }, + }, + }, + expectedError: "elicit schema property \"user\" contains nested properties, only primitive properties are allowed", + }, + { + name: "property with explicit object type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "config": {Type: "object"}, + }, + }, + expectedError: "elicit schema property \"config\" has unsupported type \"object\", only string, number, integer, and boolean are allowed", + }, + { + name: "array property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "tags": {Type: "array", Items: &jsonschema.Schema{Type: "string"}}, + }, + }, + expectedError: "elicit schema property \"tags\" has unsupported type \"array\", only string, number, integer, and boolean are allowed", + }, + { + name: "array without items", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "items": {Type: "array"}, + }, + }, + expectedError: "elicit schema property \"items\" has unsupported type \"array\", only string, number, integer, and boolean are allowed", + }, + { + name: "unsupported string format", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "phone": {Type: "string", Format: "phone"}, + }, + }, + expectedError: "elicit schema property \"phone\" has unsupported format \"phone\", only email, uri, date, and date-time are allowed", + }, + { + name: "unsupported type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "data": {Type: "null"}, + }, + }, + expectedError: "elicit schema property \"data\" has unsupported type \"null\"", + }, + { + name: "string with invalid minLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(-1)}, + }, + }, + expectedError: "elicit schema property \"name\" has invalid minLength -1, must be non-negative", + }, + { + name: "string with invalid maxLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MaxLength: ptr(-5)}, + }, + }, + expectedError: "elicit schema property \"name\" has invalid maxLength -5, must be non-negative", + }, + { + name: "string with maxLength less than minLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(10), MaxLength: ptr(5)}, + }, + }, + expectedError: "elicit schema property \"name\" has maxLength 5 less than minLength 10", + }, + { + name: "number with maximum less than minimum", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "score": {Type: "number", Minimum: ptr(100.0), Maximum: ptr(50.0)}, + }, + }, + expectedError: "elicit schema property \"score\" has maximum 50 less than minimum 100", + }, + { + name: "boolean with invalid default", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "enabled": {Type: "boolean", Default: json.RawMessage(`"not-a-boolean"`)}, + }, + }, + expectedError: "elicit schema property \"enabled\" has invalid default value, must be a boolean", + }, + { + name: "enum with mismatched enumNames length", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + Enum: []any{ + "high", + "medium", + "low", + }, + Extra: map[string]any{ + "enumNames": []any{"High Priority", "Medium Priority"}, // Only 2 names for 3 values + }, + }, + }, + }, + expectedError: "elicit schema property \"priority\" has 3 enum values but 2 enumNames, they must match", + }, + { + name: "enum with invalid enumNames type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "status": { + Type: "string", + Enum: []any{ + "active", + "inactive", + }, + Extra: map[string]any{ + "enumNames": "not an array", // Should be array + }, + }, + }, + }, + expectedError: "elicit schema property \"status\" has invalid enumNames type, must be an array", + }, + { + name: "enum without explicit type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Enum: []any{ + "high", + "medium", + "low", + }, + }, + }, + }, + expectedError: "elicit schema property \"priority\" has unsupported type \"\", only string, number, integer, and boolean are allowed", + }, + { + name: "untyped property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "data": {}, + }, + }, + expectedError: "elicit schema property \"data\" has unsupported type \"\", only string, number, integer, and boolean are allowed", + }, + } + + for _, tc := range invalidSchemas { + t.Run("invalid_"+tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test invalid schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if err == nil { + t.Errorf("expected error for invalid schema %q, got nil", tc.name) + return + } + if code := errorCode(err); code != CodeInvalidParams { + t.Errorf("got error code %d, want %d (CodeInvalidParams)", code, CodeInvalidParams) + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Errorf("error message %q does not contain expected text %q", err.Error(), tc.expectedError) + } + }) + } +} + +func TestElicitationProgressToken(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + params := &ElicitParams{ + Message: "Test progress token", + Meta: Meta{}, + } + params.SetProgressToken("test-token") + + if token := params.GetProgressToken(); token != "test-token" { + t.Errorf("got progress token %v, want %q", token, "test-token") + } + + _, err = ss.Elicit(ctx, params) + if err != nil { + t.Fatal(err) + } +} + +func TestElicitationCapabilityDeclaration(t *testing.T) { + ctx := context.Background() + + t.Run("with_handler", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + // Client with ElicitationHandler should declare capability + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "cancel"}, nil + }, + }) + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // The client should have declared elicitation capability during initialization + // We can verify this worked by successfully making an elicitation call + result, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test capability", + RequestedSchema: &jsonschema.Schema{Type: "object"}, + }) + if err != nil { + t.Fatalf("elicitation should work when capability is declared, got error: %v", err) + } + if result.Action != "cancel" { + t.Errorf("got action %q, want %q", result.Action, "cancel") + } + }) + + t.Run("without_handler", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + // Client without ElicitationHandler should not declare capability + c := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Elicitation should fail with UnsupportedMethod + _, err = ss.Elicit(ctx, &ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{Type: "object"}, + }) + + if err == nil { + t.Error("expected UnsupportedMethod error when no capability declared") + } + if code := errorCode(err); code != CodeUnsupportedMethod { + t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, CodeUnsupportedMethod) + } + }) +} + func TestKeepAliveFailure(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -896,9 +1475,9 @@ func TestKeepAliveFailure(t *testing.T) { ct, st := NewInMemoryTransports() // Server without keepalive (to test one-sided keepalive) - s := NewServer("testServer", "v1.0.0", nil) - s.AddTools(NewServerTool("greet", "say hi", sayHi)) - ss, err := s.Connect(ctx, st) + s := NewServer(testImpl, nil) + AddTool(s, greetTool(), sayHi) + ss, err := s.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } @@ -907,8 +1486,8 @@ func TestKeepAliveFailure(t *testing.T) { clientOpts := &ClientOptions{ KeepAlive: 50 * time.Millisecond, } - c := NewClient("testClient", "v1.0.0", clientOpts) - cs, err := c.Connect(ctx, ct) + c := NewClient(testImpl, clientOpts) + cs, err := c.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } @@ -936,3 +1515,403 @@ func TestKeepAliveFailure(t *testing.T) { t.Errorf("expected connection to be closed by keepalive, but it wasn't. Last error: %v", err) } + +func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { + // Adding the same tool pointer twice should not panic and should not + // produce duplicates in the server's tool list. + cs, _, cleanup := basicConnection(t, func(s *Server) { + // Use two distinct Tool instances with the same name but different + // descriptions to ensure the second replaces the first + // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors + t1 := &Tool{Name: "dup", Description: "first", InputSchema: &jsonschema.Schema{Type: "object"}} + t2 := &Tool{Name: "dup", Description: "second", InputSchema: &jsonschema.Schema{Type: "object"}} + s.AddTool(t1, nopHandler) + s.AddTool(t2, nopHandler) + }) + defer cleanup() + + ctx := context.Background() + res, err := cs.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + var count int + var gotDesc string + for _, tt := range res.Tools { + if tt.Name == "dup" { + count++ + gotDesc = tt.Description + } + } + if count != 1 { + t.Fatalf("expected exactly one 'dup' tool, got %d", count) + } + if gotDesc != "second" { + t.Fatalf("expected replaced tool to have description %q, got %q", "second", gotDesc) + } +} + +func TestSynchronousNotifications(t *testing.T) { + var toolsChanged atomic.Bool + clientOpts := &ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) { + toolsChanged.Store(true) + }, + CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + if !toolsChanged.Load() { + return nil, fmt.Errorf("didn't get a tools changed notification") + } + // TODO(rfindley): investigate the error returned from this test if + // CreateMessageResult is new(CreateMessageResult): it's a mysterious + // unmarshalling error that we should improve. + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + + var rootsChanged atomic.Bool + serverOpts := &ServerOptions{ + RootsListChangedHandler: func(_ context.Context, req *RootsListChangedRequest) { + rootsChanged.Store(true) + }, + } + server := NewServer(testImpl, serverOpts) + cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + if !rootsChanged.Load() { + return nil, nil, fmt.Errorf("didn't get root change notification") + } + return new(CallToolResult), nil, nil + }) + }) + defer cleanup() + + t.Run("from client", func(t *testing.T) { + client.AddRoots(&Root{Name: "myroot", URI: "file://foo"}) + res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"}) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + if res.IsError { + t.Errorf("tool error: %v", res.Content[0].(*TextContent).Text) + } + }) + + t.Run("from server", func(t *testing.T) { + server.RemoveTools("tool") + if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil { + t.Errorf("CreateMessage failed: %v", err) + } + }) +} + +func TestNoDistributedDeadlock(t *testing.T) { + // This test verifies that calls are asynchronous, and so it's not possible + // to have a distributed deadlock. + // + // The setup creates potential deadlock for both the client and server: the + // client sends a call to tool1, which itself calls createMessage, which in + // turn calls tool2, which calls ping. + // + // If the server were not asynchronous, the call to tool2 would hang. If the + // client were not asynchronous, the call to ping would hang. + // + // Such a scenario is unlikely in practice, but is still theoretically + // possible, and in any case making tool calls asynchronous by default + // delegates synchronization to the user. + clientOpts := &ClientOptions{ + CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}) + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + cs, _, cleanup := basicClientServerConnection(t, client, nil, func(s *Server) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + req.Session.CreateMessage(ctx, new(CreateMessageParams)) + return new(CallToolResult), nil, nil + }) + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + req.Session.Ping(ctx, nil) + return new(CallToolResult), nil, nil + }) + }) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil { + // should not deadlock + t.Fatalf("CallTool failed: %v", err) + } +} + +var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} + +// This test checks that when we use pointer types for tools, we get the same +// schema as when using the non-pointer types. It is too much of a footgun for +// there to be a difference (see #199 and #200). +// +// If anyone asks, we can add an option that controls how pointers are treated. +func TestPointerArgEquivalence(t *testing.T) { + type input struct { + In string `json:",omitempty"` + } + type output struct { + Out string + } + cs, _, cleanup := basicConnection(t, func(s *Server) { + // Add two equivalent tools, one of which operates in the 'pointer' realm, + // the other of which does not. + // + // We handle a few different types of results, to assert they behave the + // same in all cases. + AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) { + switch in.In { + case "": + return nil, nil, fmt.Errorf("must provide input") + case "nil": + return nil, nil, nil + case "empty": + return &CallToolResult{}, nil, nil + case "ok": + return &CallToolResult{}, &output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *CallToolRequest, in input) (*CallToolResult, output, error) { + switch in.In { + case "": + return nil, output{}, fmt.Errorf("must provide input") + case "nil": + return nil, output{}, nil + case "empty": + return &CallToolResult{}, output{}, nil + case "ok": + return &CallToolResult{}, output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + }) + defer cleanup() + + ctx := context.Background() + tools, err := cs.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + if got, want := len(tools.Tools), 2; got != want { + t.Fatalf("got %d tools, want %d", got, want) + } + t0 := tools.Tools[0] + t1 := tools.Tools[1] + + // First, check that the tool schemas don't differ. + if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" { + t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" { + t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + + // Then, check that we handle empty input equivalently. + for _, args := range []any{nil, struct{}{}} { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { + t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) + } + } + + // Then, check that we handle different types of output equivalently. + for _, in := range []string{"nil", "empty", "ok"} { + t.Run(in, func(t *testing.T) { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { + t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) + } + }) + } +} + +// ptr is a helper function to create pointers for schema constraints +func ptr[T any](v T) *T { + return &v +} + +func TestComplete(t *testing.T) { + completionValues := []string{"python", "pytorch", "pyside"} + + serverOpts := &ServerOptions{ + CompletionHandler: func(_ context.Context, request *CompleteRequest) (*CompleteResult, error) { + return &CompleteResult{ + Completion: CompletionResultDetails{ + Values: completionValues, + }, + }, nil + }, + } + server := NewServer(testImpl, serverOpts) + cs, _, cleanup := basicClientServerConnection(t, nil, server, func(s *Server) {}) + defer cleanup() + + result, err := cs.Complete(context.Background(), &CompleteParams{ + Argument: CompleteParamsArgument{ + Name: "language", + Value: "py", + }, + Ref: &CompleteReference{ + Type: "ref/prompt", + Name: "code_review", + }, + }) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(completionValues, result.Completion.Values); diff != "" { + t.Errorf("Complete() mismatch (-want +got):\n%s", diff) + } +} + +// TestEmbeddedStructResponse performs a tool call to verify that a struct with +// an embedded pointer generates a correct, flattened JSON schema and that its +// response is validated successfully. +func TestEmbeddedStructResponse(t *testing.T) { + type foo struct { + ID string `json:"id"` + Name string `json:"name"` + } + + // bar embeds foo + type bar struct { + *foo // Embedded - should flatten in JSON + Extra string `json:"extra"` + } + + type response struct { + Data bar `json:"data"` + } + + // testTool demonstrates an embedded struct in its response. + testTool := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, response, error) { + response := response{ + Data: bar{ + foo: &foo{ + ID: "foo", + Name: "Test Foo", + }, + Extra: "additional data", + }, + } + return nil, response, nil + } + ctx := context.Background() + clientTransport, serverTransport := NewInMemoryTransports() + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + AddTool(server, &Tool{ + Name: "test_embedded_struct", + }, testTool) + + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + t.Fatal(err) + } + defer serverSession.Close() + + client := NewClient(&Implementation{Name: "test-client"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatal(err) + } + defer clientSession.Close() + + _, err = clientSession.CallTool(ctx, &CallToolParams{ + Name: "test_embedded_struct", + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } +} + +func TestToolErrorMiddleware(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return nil, nil, errTestFailure + }) + + var middleErr error + s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + res, err := h(ctx, method, req) + if err == nil { + if ctr, ok := res.(*CallToolResult); ok { + middleErr = ctr.getError() + } + } + return res, err + } + }) + _, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + client := NewClient(&Implementation{Name: "test-client"}, nil) + clientSession, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer clientSession.Close() + + _, err = clientSession.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "al"}, + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } + if middleErr != nil { + t.Errorf("middleware got error %v, want nil", middleErr) + } + res, err := clientSession.CallTool(ctx, &CallToolParams{ + Name: "fail", + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } + if !res.IsError { + t.Fatal("want error, got none") + } + // Clients can't see the error, because it isn't marshaled. + if err := res.getError(); err != nil { + t.Fatalf("got %v, want nil", err) + } + if middleErr != errTestFailure { + t.Errorf("middleware got err %v, want errTestFailure", middleErr) + } +} + +var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} diff --git a/mcp/prompt.go b/mcp/prompt.go index e2db7b27..62f38a36 100644 --- a/mcp/prompt.go +++ b/mcp/prompt.go @@ -9,10 +9,9 @@ import ( ) // A PromptHandler handles a call to prompts/get. -type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error) +type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error) -// A Prompt is a prompt definition bound to a prompt handler. -type ServerPrompt struct { - Prompt *Prompt - Handler PromptHandler +type serverPrompt struct { + prompt *Prompt + handler PromptHandler } diff --git a/mcp/protocol.go b/mcp/protocol.go index eba9e73d..f3f23f58 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -14,7 +14,7 @@ import ( "encoding/json" "fmt" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) // Optional annotations for the client. The client can use annotations to inform @@ -40,48 +40,104 @@ type Annotations struct { Priority float64 `json:"priority,omitempty"` } -type CallToolParams = CallToolParamsFor[any] +// CallToolParams is used by clients to call a tool. +type CallToolParams struct { + // Meta is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // Name is the name of the tool to call. + Name string `json:"name"` + // Arguments holds the tool arguments. It can hold any value that can be + // marshaled to JSON. + Arguments any `json:"arguments,omitempty"` +} -type CallToolParamsFor[In any] struct { +// CallToolParamsRaw is passed to tool handlers on the server. Its arguments +// are not yet unmarshaled (hence "raw"), so that the handlers can perform +// unmarshaling themselves. +type CallToolParamsRaw struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. - Meta `json:"_meta,omitempty"` - Name string `json:"name"` - Arguments In `json:"arguments,omitempty"` + Meta `json:"_meta,omitempty"` + // Name is the name of the tool being called. + Name string `json:"name"` + // Arguments is the raw arguments received over the wire from the client. It + // is the responsibility of the tool handler to unmarshal and validate the + // Arguments (see [AddTool]). + Arguments json.RawMessage `json:"arguments,omitempty"` } -// The server's response to a tool call. -type CallToolResult = CallToolResultFor[any] - -type CallToolResultFor[Out any] struct { +// A CallToolResult is the server's response to a tool call. +// +// The [ToolHandler] and [ToolHandlerFor] handler functions return this result, +// though [ToolHandlerFor] populates much of it automatically as documented at +// each field. +type CallToolResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` + // A list of content objects that represent the unstructured result of the tool // call. + // + // When using a [ToolHandlerFor] with structured output, if Content is unset + // it will be populated with JSON text content corresponding to the + // structured output value. Content []Content `json:"content"` - // An optional JSON object that represents the structured result of the tool - // call. - StructuredContent Out `json:"structuredContent,omitempty"` - // Whether the tool call ended in an error. + + // StructuredContent is an optional value that represents the structured + // result of the tool call. It must marshal to a JSON object. + // + // When using a [ToolHandlerFor] with structured output, you should not + // populate this field. It will be automatically populated with the typed Out + // value. + StructuredContent any `json:"structuredContent,omitempty"` + + // IsError reports whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). // - // Any errors that originate from the tool should be reported inside the result - // object, with isError set to true, not as an MCP protocol-level error - // response. Otherwise, the LLM would not be able to see that an error occurred - // and self-correct. + // Any errors that originate from the tool should be reported inside the + // Content field, with IsError set to true, not as an MCP protocol-level + // error response. Otherwise, the LLM would not be able to see that an error + // occurred and self-correct. // // However, any errors in finding the tool, an error indicating that the // server does not support tool calls, or any other exceptional conditions, // should be reported as an MCP error response. + // + // When using a [ToolHandlerFor], this field is automatically set when the + // tool handler returns an error, and the error string is included as text in + // the Content field. IsError bool `json:"isError,omitempty"` + + // The error passed to setError, if any. + // It is not marshaled, and therefore it is only visible on the server. + // Its only use is in server sending middleware, where it can be accessed + // with getError. + err error +} + +// TODO(#64): consider exposing setError (and getError), by adding an error +// field on CallToolResult. +func (r *CallToolResult) setError(err error) { + r.Content = []Content{&TextContent{Text: err.Error()}} + r.IsError = true + r.err = err +} + +// getError returns the error set with setError, or nil if none. +// This function always returns nil on clients. +func (r *CallToolResult) getError() error { + return r.err } +func (*CallToolResult) isResult() {} + // UnmarshalJSON handles the unmarshalling of content into the Content // interface. -func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { - type res CallToolResultFor[Out] // avoid recursion +func (x *CallToolResult) UnmarshalJSON(data []byte) error { + type res CallToolResult // avoid recursion var wire struct { res Content []*wireContent `json:"content"` @@ -93,12 +149,17 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } - *x = CallToolResultFor[Out](wire.res) + *x = CallToolResult(wire.res) return nil } -func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) } -func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *CallToolParams) isParams() {} +func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } + +func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } type CancelledParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -114,6 +175,7 @@ type CancelledParams struct { RequestID any `json:"requestId"` } +func (x *CancelledParams) isParams() {} func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -122,7 +184,7 @@ func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } // additional capabilities. type ClientCapabilities struct { // Experimental, non-standard capabilities that the client supports. - Experimental map[string]struct{} `json:"experimental,omitempty"` + Experimental map[string]any `json:"experimental,omitempty"` // Present if the client supports listing roots. Roots struct { // Whether the client supports notifications for changes to the roots list. @@ -207,6 +269,8 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } +func (*CompleteParams) isParams() {} + type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` Total int `json:"total,omitempty"` @@ -221,6 +285,8 @@ type CompleteResult struct { Completion CompletionResultDetails `json:"completion"` } +func (*CompleteResult) isResult() {} + type CreateMessageParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -234,7 +300,7 @@ type CreateMessageParams struct { Messages []*SamplingMessage `json:"messages"` // Optional metadata to pass through to the LLM provider. The format of this // metadata is provider-specific. - Metadata struct{} `json:"metadata,omitempty"` + Metadata any `json:"metadata,omitempty"` // The server's preferences for which model to select. The client may ignore // these preferences. ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` @@ -245,6 +311,7 @@ type CreateMessageParams struct { Temperature float64 `json:"temperature,omitempty"` } +func (x *CreateMessageParams) isParams() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -264,6 +331,24 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } +func (*CreateMessageResult) isResult() {} +func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { + type result CreateMessageResult // avoid recursion + var wire struct { + result + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *r = CreateMessageResult(wire.result) + return nil +} + type GetPromptParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -274,6 +359,7 @@ type GetPromptParams struct { Name string `json:"name"` } +func (x *GetPromptParams) isParams() {} func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -287,17 +373,20 @@ type GetPromptResult struct { Messages []*PromptMessage `json:"messages"` } +func (*GetPromptResult) isResult() {} + type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Capabilities *ClientCapabilities `json:"capabilities"` - ClientInfo *implementation `json:"clientInfo"` + ClientInfo *Implementation `json:"clientInfo"` // The latest version of the Model Context Protocol that the client supports. // The client may decide to support older versions as well. ProtocolVersion string `json:"protocolVersion"` } +func (x *InitializeParams) isParams() {} func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -307,7 +396,7 @@ type InitializeResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` - Capabilities *serverCapabilities `json:"capabilities"` + Capabilities *ServerCapabilities `json:"capabilities"` // Instructions describing how to use the server and its features. // // This can be used by clients to improve the LLM's understanding of available @@ -318,15 +407,18 @@ type InitializeResult struct { // may not match the version that the client requested. If the client cannot // support this version, it must disconnect. ProtocolVersion string `json:"protocolVersion"` - ServerInfo *implementation `json:"serverInfo"` + ServerInfo *Implementation `json:"serverInfo"` } +func (*InitializeResult) isResult() {} + type InitializedParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` } +func (x *InitializedParams) isParams() {} func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -339,6 +431,7 @@ type ListPromptsParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListPromptsParams) isParams() {} func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -354,6 +447,7 @@ type ListPromptsResult struct { Prompts []*Prompt `json:"prompts"` } +func (x *ListPromptsResult) isResult() {} func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourceTemplatesParams struct { @@ -365,6 +459,7 @@ type ListResourceTemplatesParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListResourceTemplatesParams) isParams() {} func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -380,6 +475,7 @@ type ListResourceTemplatesResult struct { ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` } +func (x *ListResourceTemplatesResult) isResult() {} func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourcesParams struct { @@ -391,6 +487,7 @@ type ListResourcesParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListResourcesParams) isParams() {} func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -406,6 +503,7 @@ type ListResourcesResult struct { Resources []*Resource `json:"resources"` } +func (x *ListResourcesResult) isResult() {} func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } type ListRootsParams struct { @@ -414,6 +512,7 @@ type ListRootsParams struct { Meta `json:"_meta,omitempty"` } +func (x *ListRootsParams) isParams() {} func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -427,6 +526,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } +func (*ListRootsResult) isResult() {} + type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -436,6 +537,7 @@ type ListToolsParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListToolsParams) isParams() {} func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -451,6 +553,7 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } +func (x *ListToolsResult) isResult() {} func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } // The severity of a log message. @@ -472,6 +575,7 @@ type LoggingMessageParams struct { Logger string `json:"logger,omitempty"` } +func (x *LoggingMessageParams) isParams() {} func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -533,6 +637,7 @@ type PingParams struct { Meta `json:"_meta,omitempty"` } +func (x *PingParams) isParams() {} func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -540,18 +645,21 @@ type ProgressNotificationParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` + // The progress token which was given in the initial request, used to associate + // this notification with the request that is proceeding. + ProgressToken any `json:"progressToken"` // An optional message describing the current progress. Message string `json:"message,omitempty"` // The progress thus far. This should increase every time progress is made, even // if the total is unknown. Progress float64 `json:"progress"` - // The progress token which was given in the initial request, used to associate - // this notification with the request that is proceeding. - ProgressToken any `json:"progressToken"` // Total number of items to process (or total progress required), if known. + // Zero means unknown. Total float64 `json:"total,omitempty"` } +func (*ProgressNotificationParams) isParams() {} + // A prompt or prompt template that the server offers. type Prompt struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -589,6 +697,7 @@ type PromptListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *PromptListChangedParams) isParams() {} func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -629,6 +738,7 @@ type ReadResourceParams struct { URI string `json:"uri"` } +func (x *ReadResourceParams) isParams() {} func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -640,6 +750,8 @@ type ReadResourceResult struct { Contents []*ResourceContents `json:"contents"` } +func (*ReadResourceResult) isResult() {} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -680,6 +792,7 @@ type ResourceListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *ResourceListChangedParams) isParams() {} func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -737,6 +850,7 @@ type RootsListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *RootsListChangedParams) isParams() {} func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -771,7 +885,7 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error { return nil } -type SetLevelParams struct { +type SetLoggingLevelParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` @@ -781,8 +895,9 @@ type SetLevelParams struct { Level LoggingLevel `json:"level"` } -func (x *SetLevelParams) GetProgressToken() any { return getProgressToken(x) } -func (x *SetLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *SetLoggingLevelParams) isParams() {} +func (x *SetLoggingLevelParams) GetProgressToken() any { return getProgressToken(x) } +func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } // Definition for a tool the client can call. type Tool struct { @@ -856,16 +971,87 @@ type ToolListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *ToolListChangedParams) isParams() {} func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } +// Sent from the client to request resources/updated notifications from the +// server whenever a particular resource changes. +type SubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to subscribe to. + URI string `json:"uri"` +} + +func (*SubscribeParams) isParams() {} + +// Sent from the client to request cancellation of resources/updated +// notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +func (*UnsubscribeParams) isParams() {} + +// A notification from the server to the client, informing it that a resource +// has changed and may need to be read again. This should only be sent if the +// client previously sent a resources/subscribe request. +type ResourceUpdatedNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + +func (*ResourceUpdatedNotificationParams) isParams() {} + // TODO(jba): add CompleteRequest and related types. -// TODO(jba): add ElicitRequest and related types. +// A request from the server to elicit additional information from the user via the client. +type ElicitParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The message to present to the user. + Message string `json:"message"` + // A restricted subset of JSON Schema. + // Only top-level properties are allowed, without nesting. + RequestedSchema *jsonschema.Schema `json:"requestedSchema"` +} + +func (x *ElicitParams) isParams() {} + +func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to an elicitation/create request from the server. +type ElicitResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The user action in response to the elicitation. + // - "accept": User submitted the form/confirmed the action + // - "decline": User explicitly declined the action + // - "cancel": User dismissed without making an explicit choice + Action string `json:"action"` + // The submitted form data, only present when action is "accept". + // Contains values matching the requested schema. + Content map[string]any `json:"content,omitempty"` +} + +func (*ElicitResult) isResult() {} -// Describes the name and version of an MCP implementation, with an optional +// An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. -type implementation struct { +type Implementation struct { // Intended for programmatic or logical use, but used as a display name in past // specs or fallback (if title isn't present). Name string `json:"name"` @@ -876,19 +1062,19 @@ type implementation struct { } // Present if the server supports argument autocompletion suggestions. -type completionCapabilities struct{} +type CompletionCapabilities struct{} // Present if the server supports sending log messages to the client. -type loggingCapabilities struct{} +type LoggingCapabilities struct{} // Present if the server offers any prompt templates. -type promptCapabilities struct { +type PromptCapabilities struct { // Whether this server supports notifications for changes to the prompt list. ListChanged bool `json:"listChanged,omitempty"` } // Present if the server offers any resources to read. -type resourceCapabilities struct { +type ResourceCapabilities struct { // Whether this server supports notifications for changes to the resource list. ListChanged bool `json:"listChanged,omitempty"` // Whether this server supports subscribing to resource updates. @@ -898,23 +1084,23 @@ type resourceCapabilities struct { // Capabilities that a server may support. Known capabilities are defined here, // in this schema, but this is not a closed set: any server can define its own, // additional capabilities. -type serverCapabilities struct { +type ServerCapabilities struct { // Present if the server supports argument autocompletion suggestions. - Completions *completionCapabilities `json:"completions,omitempty"` + Completions *CompletionCapabilities `json:"completions,omitempty"` // Experimental, non-standard capabilities that the server supports. - Experimental map[string]struct{} `json:"experimental,omitempty"` + Experimental map[string]any `json:"experimental,omitempty"` // Present if the server supports sending log messages to the client. - Logging *loggingCapabilities `json:"logging,omitempty"` + Logging *LoggingCapabilities `json:"logging,omitempty"` // Present if the server offers any prompt templates. - Prompts *promptCapabilities `json:"prompts,omitempty"` + Prompts *PromptCapabilities `json:"prompts,omitempty"` // Present if the server offers any resources to read. - Resources *resourceCapabilities `json:"resources,omitempty"` + Resources *ResourceCapabilities `json:"resources,omitempty"` // Present if the server offers any tools to call. - Tools *toolCapabilities `json:"tools,omitempty"` + Tools *ToolCapabilities `json:"tools,omitempty"` } // Present if the server offers any tools to call. -type toolCapabilities struct { +type ToolCapabilities struct { // Whether this server supports notifications for changes to the tool list. ListChanged bool `json:"listChanged,omitempty"` } diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index dba80a8b..67d021d1 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -208,6 +208,7 @@ func TestCompleteReference(t *testing.T) { }) } } + func TestCompleteParams(t *testing.T) { // Define test cases specifically for Marshalling marshalTests := []struct { @@ -498,7 +499,7 @@ func TestContentUnmarshal(t *testing.T) { if err := json.Unmarshal(data, out); err != nil { t.Fatal(err) } - if diff := cmp.Diff(in, out); diff != "" { + if diff := cmp.Diff(in, out, ctrCmpOpts...); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } @@ -514,13 +515,13 @@ func TestContentUnmarshal(t *testing.T) { var got CallToolResult roundtrip(ctr, &got) - ctrf := &CallToolResultFor[int]{ + ctrf := &CallToolResult{ Meta: Meta{"m": true}, Content: content, IsError: true, - StructuredContent: 3, + StructuredContent: 3.0, } - var gotf CallToolResultFor[int] + var gotf CallToolResult roundtrip(ctrf, &gotf) pm := &PromptMessage{ diff --git a/mcp/requests.go b/mcp/requests.go new file mode 100644 index 00000000..82b700f5 --- /dev/null +++ b/mcp/requests.go @@ -0,0 +1,37 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file holds the request types. + +package mcp + +type ( + CallToolRequest = ServerRequest[*CallToolParamsRaw] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) + +type ( + CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] + initializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] +) diff --git a/mcp/resource.go b/mcp/resource.go index 18e0bec4..0658c661 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -13,29 +13,29 @@ import ( "net/url" "os" "path/filepath" - "regexp" "strings" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/yosida95/uritemplate/v3" ) -// A ServerResource associates a Resource with its handler. -type ServerResource struct { - Resource *Resource - Handler ResourceHandler +// A serverResource associates a Resource with its handler. +type serverResource struct { + resource *Resource + handler ResourceHandler } -// A ServerResourceTemplate associates a ResourceTemplate with its handler. -type ServerResourceTemplate struct { - ResourceTemplate *ResourceTemplate - Handler ResourceHandler +// A serverResourceTemplate associates a ResourceTemplate with its handler. +type serverResourceTemplate struct { + resourceTemplate *ResourceTemplate + handler ResourceHandler } // A ResourceHandler is a function that reads a resource. // It will be called when the client calls [ClientSession.ReadResource]. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) (*ReadResourceResult, error) +type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceResult, error) // ResourceNotFoundError returns an error indicating that a resource being read could // not be found. @@ -155,67 +155,10 @@ func fileRoot(root *Root) (_ string, err error) { } // Matches reports whether the receiver's uri template matches the uri. -// TODO: use "github.com/yosida95/uritemplate/v3" -func (sr *ServerResourceTemplate) Matches(uri string) bool { - re, err := uriTemplateToRegexp(sr.ResourceTemplate.URITemplate) +func (sr *serverResourceTemplate) Matches(uri string) bool { + tmpl, err := uritemplate.New(sr.resourceTemplate.URITemplate) if err != nil { return false } - return re.MatchString(uri) -} - -func uriTemplateToRegexp(uriTemplate string) (*regexp.Regexp, error) { - pat := uriTemplate - var b strings.Builder - b.WriteByte('^') - seen := map[string]bool{} - for len(pat) > 0 { - literal, rest, ok := strings.Cut(pat, "{") - b.WriteString(regexp.QuoteMeta(literal)) - if !ok { - break - } - expr, rest, ok := strings.Cut(rest, "}") - if !ok { - return nil, errors.New("missing '}'") - } - pat = rest - if strings.ContainsRune(expr, ',') { - return nil, errors.New("can't handle commas in expressions") - } - if strings.ContainsRune(expr, ':') { - return nil, errors.New("can't handle prefix modifiers in expressions") - } - if len(expr) > 0 && expr[len(expr)-1] == '*' { - return nil, errors.New("can't handle explode modifiers in expressions") - } - - // These sets of valid characters aren't accurate. - // See https://datatracker.ietf.org/doc/html/rfc6570. - var re, name string - first := byte(0) - if len(expr) > 0 { - first = expr[0] - } - switch first { - default: - // {var} doesn't match slashes. (It should also fail to match other characters, - // but this simplified implementation doesn't handle that.) - re = `[^/]*` - name = expr - case '+': - // {+var} matches anything, even slashes - re = `.*` - name = expr[1:] - case '#', '.', '/', ';', '?', '&': - return nil, fmt.Errorf("prefix character %c unsupported", first) - } - if seen[name] { - return nil, fmt.Errorf("can't handle duplicate name %q", name) - } - seen[name] = true - b.WriteString(re) - } - b.WriteByte('$') - return regexp.Compile(b.String()) + return tmpl.Regexp().MatchString(uri) } diff --git a/mcp/resource_test.go b/mcp/resource_test.go index cb5e4fb9..74d82f12 100644 --- a/mcp/resource_test.go +++ b/mcp/resource_test.go @@ -119,18 +119,14 @@ func TestTemplateMatch(t *testing.T) { template string want bool }{ - {"file:///{}/{a}/{b}", true}, + {"file:///{}/{a}/{b}", false}, // invalid: empty variable expression "{}" is not allowed in RFC 6570 {"file:///{a}/{b}", false}, {"file:///{+path}", true}, {"file:///{a}/{+path}", true}, } { - re, err := uriTemplateToRegexp(tt.template) - if err != nil { - t.Fatalf("%s: %v", tt.template, err) - } - got := re.MatchString(uri) - if got != tt.want { - t.Errorf("%s: got %t, want %t", tt.template, got, tt.want) + resourceTmpl := serverResourceTemplate{resourceTemplate: &ResourceTemplate{URITemplate: tt.template}} + if matched := resourceTmpl.Matches(uri); matched != tt.want { + t.Errorf("%s: got %t, want %t", tt.template, matched, tt.want) } } } diff --git a/mcp/root.go b/mcp/root.go deleted file mode 100644 index b56ad991..00000000 --- a/mcp/root.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package mcp diff --git a/mcp/server.go b/mcp/server.go index 69666a6a..27de09a3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -12,15 +12,19 @@ import ( "encoding/json" "fmt" "iter" + "maps" "net/url" "path/filepath" + "reflect" "slices" "sync" "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/yosida95/uritemplate/v3" ) const DefaultPageSize = 1000 @@ -28,21 +32,21 @@ const DefaultPageSize = 1000 // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP -// sessions by using [Server.Start] or [Server.Run]. +// sessions by using [Server.Run]. type Server struct { // fixed at creation - name string - version string - opts ServerOptions + impl *Implementation + opts ServerOptions mu sync.Mutex - prompts *featureSet[*ServerPrompt] - tools *featureSet[*ServerTool] - resources *featureSet[*ServerResource] - resourceTemplates *featureSet[*ServerResourceTemplate] + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] sessions []*ServerSession - sendingMethodHandler_ MethodHandler[*ServerSession] - receivingMethodHandler_ MethodHandler[*ServerSession] + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler + resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool } // ServerOptions is used to configure behavior of the server. @@ -50,65 +54,102 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *ServerSession, *InitializedParams) + InitializedHandler func(context.Context, *InitializedRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). + // + // If zero, defaults to [DefaultPageSize]. PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. - RootsListChangedHandler func(context.Context, *ServerSession, *RootsListChangedParams) + RootsListChangedHandler func(context.Context, *RootsListChangedRequest) // If non-nil, called when "notifications/progress" is received. - ProgressNotificationHandler func(context.Context, *ServerSession, *ProgressNotificationParams) + ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest) // If non-nil, called when "completion/complete" is received. - CompletionHandler func(context.Context, *ServerSession, *CompleteParams) (*CompleteResult, error) + CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeRequest) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error + // If true, advertises the prompts capability during initialization, + // even if no prompts have been registered. + HasPrompts bool + // If true, advertises the resources capability during initialization, + // even if no resources have been registered. + HasResources bool + // If true, advertises the tools capability during initialization, + // even if no tools have been registered. + HasTools bool + + // GetSessionID provides the next session ID to use for an incoming request. + // If nil, a default randomly generated ID will be used. + // + // Session IDs should be globally unique across the scope of the server, + // which may span multiple processes in the case of distributed servers. + // + // As a special case, if GetSessionID returns the empty string, the + // Mcp-Session-Id header will not be set. + GetSessionID func() string } // NewServer creates a new MCP server. The resulting server has no features: -// add features using [Server.AddTools], [Server.AddPrompts] and [Server.AddResources]. +// add features using the various Server.AddXXX methods, and the [AddTool] function. // -// The server can be connected to one or more MCP clients using [Server.Start] -// or [Server.Run]. +// The server can be connected to one or more MCP clients using [Server.Run]. // -// If non-nil, the provided options is used to configure the server. -func NewServer(name, version string, opts *ServerOptions) *Server { - if opts == nil { - opts = new(ServerOptions) +// The first argument must not be nil. +// +// If non-nil, the provided options are used to configure the server. +func NewServer(impl *Implementation, options *ServerOptions) *Server { + if impl == nil { + panic("nil Implementation") + } + var opts ServerOptions + if options != nil { + opts = *options } + options = nil // prevent reuse if opts.PageSize < 0 { panic(fmt.Errorf("invalid page size %d", opts.PageSize)) } if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } + if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil { + panic("SubscribeHandler requires UnsubscribeHandler") + } + if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { + panic("UnsubscribeHandler requires SubscribeHandler") + } + + if opts.GetSessionID == nil { + opts.GetSessionID = randText + } + return &Server{ - name: name, - version: version, - opts: *opts, - prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }), - tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), - resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), - resourceTemplates: newFeatureSet(func(t *ServerResourceTemplate) string { return t.ResourceTemplate.URITemplate }), + impl: impl, + opts: opts, + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + resourceSubscriptions: make(map[string]map[*ServerSession]bool), } } -// AddPrompts adds the given prompts to the server, -// replacing any with the same names. -func (s *Server) AddPrompts(prompts ...*ServerPrompt) { - // Only notify if something could change. - if len(prompts) == 0 { - return - } - // Assume there was a change, since add replaces existing roots. - // (It's possible a root was replaced with an identical one, but not worth checking.) +// AddPrompt adds a [Prompt] to the server, or replaces one with the same name. +func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { + // Assume there was a change, since add replaces existing items. + // (It's possible an item was replaced with an identical one, but not worth checking.) s.changeAndNotify( notificationPromptListChanged, &PromptListChangedParams{}, - func() bool { s.prompts.add(prompts...); return true }) + func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) } // RemovePrompts removes the prompts with the given names. @@ -118,56 +159,214 @@ func (s *Server) RemovePrompts(names ...string) { func() bool { return s.prompts.remove(names...) }) } -// AddTools adds the given tools to the server, -// replacing any with the same names. -// The arguments must not be modified after this call. +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// The Tool argument must not be modified after this call. // -// AddTools panics if errors are detected. -func (s *Server) AddTools(tools ...*ServerTool) { - if err := s.addToolsErr(tools...); err != nil { - panic(err) +// The tool's input schema must be non-nil and have the type "object". For a tool +// that takes no input, or one where any input is valid, set [Tool.InputSchema] to +// &jsonschema.Schema{Type: "object"}. +// +// If present, the output schema must also have type "object". +// +// When the handler is invoked as part of a CallTool request, req.Params.Arguments +// will be a json.RawMessage. +// +// Unmarshaling the arguments and validating them against the input schema are the +// caller's responsibility. +// +// Validating the result against the output schema, if any, is the caller's responsibility. +// +// Setting the result's Content, StructuredContent and IsError fields are the caller's +// responsibility. +// +// Most users should use the top-level function [AddTool], which handles all these +// responsibilities. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + if t.InputSchema == nil { + // This prevents the tool author from forgetting to write a schema where + // one should be provided. If we papered over this by supplying the empty + // schema, then every input would be validated and the problem wouldn't be + // discovered until runtime, when the LLM sent bad data. + panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) + } + if t.InputSchema.Type != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) } + if t.OutputSchema != nil && t.OutputSchema.Type != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + } + st := &serverTool{tool: t, handler: h} + // Assume there was a change, since add replaces existing tools. + // (It's possible a tool was replaced with an identical one, but not worth checking.) + // TODO: Batch these changes by size and time? The typescript SDK doesn't. + // TODO: Surface notify error here? best not, in case we need to batch. + s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, + func() bool { s.tools.add(st); return true }) } -// addToolsErr is like [AddTools], but returns an error instead of panicking. -func (s *Server) addToolsErr(tools ...*ServerTool) error { - // Only notify if something could change. - if len(tools) == 0 { - return nil +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { + tt := *t + + // Special handling for an "any" input: treat as an empty object. + if reflect.TypeFor[In]() == reflect.TypeFor[any]() && t.InputSchema == nil { + tt.InputSchema = &jsonschema.Schema{Type: "object"} + } + + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + elemZero any // only non-nil if Out is a pointer type + outputResolved *jsonschema.Resolved + ) + if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved) + if err != nil { + return nil, nil, fmt.Errorf("output schema: %v", err) + } } - // Wrap the user's Handlers with rawHandlers that take a json.RawMessage. - for _, st := range tools { - if st.rawHandler == nil { - // This ServerTool was not created with NewServerTool. - if st.Handler == nil { - return fmt.Errorf("AddTools: tool %q has no handler", st.Tool.Name) + + th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + var input json.RawMessage + if req.Params.Arguments != nil { + input = req.Params.Arguments + } + // Validate input and apply defaults. + var err error + input, err = applySchema(input, inputResolved) + if err != nil { + // TODO(#450): should this be considered a tool error? (and similar below) + return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err) + } + + // Unmarshal and validate args. + var in In + if input != nil { + if err := json.Unmarshal(input, &in); err != nil { + return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err) } - st.rawHandler = newRawHandler(st) - // Resolve the schemas, with no base URI. We don't expect tool schemas to - // refer outside of themselves. - if st.Tool.InputSchema != nil { - r, err := st.Tool.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return err - } - st.inputResolved = r + } + + // Call typed handler. + res, out, err := h(ctx, req, in) + // Handle server errors appropriately: + // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true + // - This allows tools to distinguish between protocol errors and tool execution errors + if err != nil { + // Check if this is already a structured JSON-RPC error + if wireErr, ok := err.(*jsonrpc2.WireError); ok { + return nil, wireErr } + // For regular errors, embed them in the tool result as per MCP spec + var errRes CallToolResult + errRes.setError(err) + return &errRes, nil + } - // if st.Tool.OutputSchema != nil { - // st.outputResolved, err := st.Tool.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - // if err != nil { - // return err - // } - // } + if res == nil { + res = &CallToolResult{} } + + // Marshal the output and put the RawMessage in the StructuredContent field. + var outval any = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the unpointered type. + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + outval = elemZero + } + } + if outval != nil { + outbytes, err := json.Marshal(outval) + if err != nil { + return nil, fmt.Errorf("marshaling output: %w", err) + } + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.StructuredContent = outJSON // avoid a second marshal over the wire + + // If the Content field isn't being used, return the serialized JSON in a + // TextContent block, as the spec suggests: + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content. + if res.Content == nil { + res.Content = []Content{&TextContent{ + Text: string(outJSON), + }} + } + } + return res, nil + } // end of handler + + return &tt, th, nil +} + +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { + if *sfield == nil { + rt := reflect.TypeFor[T]() + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + *sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) } + if err != nil { + return zero, err + } + *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} - // Assume there was a change, since add replaces existing tools. - // (It's possible a tool was replaced with an identical one, but not worth checking.) - // TODO: surface notify error here? - s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { s.tools.add(tools...); return true }) - return nil +// AddTool adds a tool and typed tool handler to the server. +// +// If the tool's input schema is nil, it is set to the schema inferred from the +// In type parameter, using [jsonschema.For]. The In type argument must be a +// map or a struct, so that its inferred JSON Schema has type "object". +// +// If the tool's output schema is nil, and the Out type is not 'any', the +// output schema is set to the schema inferred from the Out type argument, +// which also must be a map or struct. +// +// Unlike [Server.AddTool], AddTool does a lot automatically, and forces tools +// to conform to the MCP spec. See [ToolHandlerFor] for a detailed description +// of this automatic behavior. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + tt, hh, err := toolForErr(t, h) + if err != nil { + panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) + } + s.AddTool(tt, hh) } // RemoveTools removes the tools with the given names. @@ -177,26 +376,15 @@ func (s *Server) RemoveTools(names ...string) { func() bool { return s.tools.remove(names...) }) } -// AddResources adds the given resources to the server. -// If a resource with the same URI already exists, it is replaced. -// AddResources panics if a resource URI is invalid or not absolute (has an empty scheme). -func (s *Server) AddResources(resources ...*ServerResource) { - // Only notify if something could change. - if len(resources) == 0 { - return - } +// AddResource adds a [Resource] to the server, or replaces one with the same URI. +// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResource(r *Resource, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - for _, r := range resources { - u, err := url.Parse(r.Resource.URI) - if err != nil { - panic(err) // url.Parse includes the URI in the error - } - if !u.IsAbs() { - panic(fmt.Errorf("URI %s needs a scheme", r.Resource.URI)) - } - s.resources.add(r) + if _, err := url.Parse(r.URI); err != nil { + panic(err) // url.Parse includes the URI in the error } + s.resources.add(&serverResource{r, h}) return true }) } @@ -208,20 +396,17 @@ func (s *Server) RemoveResources(uris ...string) { func() bool { return s.resources.remove(uris...) }) } -// AddResourceTemplates adds the given resource templates to the server. -// If a resource template with the same URI template already exists, it will be replaced. -// AddResourceTemplates panics if a URI template is invalid or not absolute (has an empty scheme). -func (s *Server) AddResourceTemplates(templates ...*ServerResourceTemplate) { - // Only notify if something could change. - if len(templates) == 0 { - return - } +// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI. +// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). +func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - for _, t := range templates { - // TODO: check template validity. - s.resourceTemplates.add(t) + // Validate the URI template syntax + _, err := uritemplate.New(t.URITemplate) + if err != nil { + panic(fmt.Errorf("URI template %q is invalid: %w", t.URITemplate, err)) } + s.resourceTemplates.add(&serverResourceTemplate{t, h}) return true }) } @@ -233,11 +418,36 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { func() bool { return s.resourceTemplates.remove(uriTemplates...) }) } -func (s *Server) complete(ctx context.Context, ss *ServerSession, params *CompleteParams) (Result, error) { +func (s *Server) capabilities() *ServerCapabilities { + s.mu.Lock() + defer s.mu.Unlock() + + caps := &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + } + if s.opts.HasTools || s.tools.len() > 0 { + caps.Tools = &ToolCapabilities{ListChanged: true} + } + if s.opts.HasPrompts || s.prompts.len() > 0 { + caps.Prompts = &PromptCapabilities{ListChanged: true} + } + if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 { + caps.Resources = &ResourceCapabilities{ListChanged: true} + if s.opts.SubscribeHandler != nil { + caps.Resources.Subscribe = true + } + } + if s.opts.CompletionHandler != nil { + caps.Completions = &CompletionCapabilities{} + } + return caps +} + +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteResult, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } - return s.opts.CompletionHandler(ctx, ss, params) + return s.opts.CompletionHandler(ctx, req) } // changeAndNotify is called when a feature is added or removed. @@ -262,86 +472,98 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] { return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPromptsParams) (*ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, req *ListPromptsRequest) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListPromptsParams{} + if req.Params == nil { + req.Params = &ListPromptsParams{} } - return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*ServerPrompt) { + return paginateList(s.prompts, s.opts.PageSize, req.Params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { res.Prompts = []*Prompt{} // avoid JSON null for _, p := range prompts { - res.Prompts = append(res.Prompts, p.Prompt) + res.Prompts = append(res.Prompts, p.prompt) } }) } -func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { s.mu.Lock() - prompt, ok := s.prompts.get(params.Name) + prompt, ok := s.prompts.get(req.Params.Name) s.mu.Unlock() if !ok { - // TODO: surface the error code over the wire, instead of flattening it into the string. - return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) + // Return a proper JSON-RPC error with the correct error code + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), + } } - return prompt.Handler(ctx, cc, params) + return prompt.handler(ctx, req) } -func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListToolsParams{} + if req.Params == nil { + req.Params = &ListToolsParams{} } - return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*ServerTool) { + return paginateList(s.tools, s.opts.PageSize, req.Params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { res.Tools = []*Tool{} // avoid JSON null for _, t := range tools { - res.Tools = append(res.Tools, t.Tool) + res.Tools = append(res.Tools, t.tool) } }) } -func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { s.mu.Lock() - tool, ok := s.tools.get(params.Name) + st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() if !ok { - return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name) + return nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: fmt.Sprintf("unknown tool %q", req.Params.Name), + } + } + res, err := st.handler(ctx, req) + if err == nil && res != nil && res.Content == nil { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 } - return tool.rawHandler(ctx, cc, params) + return res, err } -func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { +func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListResourcesParams{} + if req.Params == nil { + req.Params = &ListResourcesParams{} } - return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*ServerResource) { + return paginateList(s.resources, s.opts.PageSize, req.Params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { res.Resources = []*Resource{} // avoid JSON null for _, r := range resources { - res.Resources = append(res.Resources, r.Resource) + res.Resources = append(res.Resources, r.resource) } }) } -func (s *Server) listResourceTemplates(_ context.Context, _ *ServerSession, params *ListResourceTemplatesParams) (*ListResourceTemplatesResult, error) { +func (s *Server) listResourceTemplates(_ context.Context, req *ListResourceTemplatesRequest) (*ListResourceTemplatesResult, error) { s.mu.Lock() defer s.mu.Unlock() - if params == nil { - params = &ListResourceTemplatesParams{} + if req.Params == nil { + req.Params = &ListResourceTemplatesParams{} } - return paginateList(s.resourceTemplates, s.opts.PageSize, params, &ListResourceTemplatesResult{}, - func(res *ListResourceTemplatesResult, rts []*ServerResourceTemplate) { + return paginateList(s.resourceTemplates, s.opts.PageSize, req.Params, &ListResourceTemplatesResult{}, + func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null for _, rt := range rts { - res.ResourceTemplates = append(res.ResourceTemplates, rt.ResourceTemplate) + res.ResourceTemplates = append(res.ResourceTemplates, rt.resourceTemplate) } }) } -func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (*ReadResourceResult, error) { - uri := params.URI +func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + uri := req.Params.URI // Look up the resource URI in the lists of resources and resource templates. // This is a security check as well as an information lookup. handler, mimeType, ok := s.lookupResourceHandler(uri) @@ -350,7 +572,7 @@ func (s *Server) readResource(ctx context.Context, ss *ServerSession, params *Re // Treat an unregistered resource the same as a registered one that couldn't be found. return nil, ResourceNotFoundError(uri) } - res, err := handler(ctx, ss, params) + res, err := handler(ctx, req) if err != nil { return nil, err } @@ -376,12 +598,12 @@ func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, boo defer s.mu.Unlock() // Try resources first. if r, ok := s.resources.get(uri); ok { - return r.Handler, r.Resource.MIMEType, true + return r.handler, r.resource.MIMEType, true } // Look for matching template. for rt := range s.resourceTemplates.all() { if rt.Matches(uri) { - return rt.Handler, rt.ResourceTemplate.MIMEType, true + return rt.handler, rt.resourceTemplate.MIMEType, true } } return nil, "", false @@ -394,26 +616,22 @@ func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, boo // The dir argument should be a filesystem path. It need not be absolute, but // that is recommended to avoid a dependency on the current working directory (the // check against client roots is done with an absolute path). If dir is not absolute -// and the current working directory is unavailable, FileResourceHandler panics. +// and the current working directory is unavailable, fileResourceHandler panics. // // Lexical path traversal attacks, where the path has ".." elements that escape dir, // are always caught. Go 1.24 and above also protects against symlink-based attacks, // where symlinks under dir lead out of the tree. -func (s *Server) fileResourceHandler(dir string) ResourceHandler { - return fileResourceHandler(dir) -} - func fileResourceHandler(dir string) ResourceHandler { // Convert dir to an absolute path. dirFilepath, err := filepath.Abs(dir) if err != nil { panic(err) } - return func(ctx context.Context, ss *ServerSession, params *ReadResourceParams) (_ *ReadResourceResult, err error) { - defer util.Wrapf(&err, "reading resource %s", params.URI) + return func(ctx context.Context, req *ReadResourceRequest) (_ *ReadResourceResult, err error) { + defer util.Wrapf(&err, "reading resource %s", req.Params.URI) - // TODO: use a memoizing API here. - rootRes, err := ss.ListRoots(ctx, nil) + // TODO(#25): use a memoizing API here. + rootRes, err := req.Session.ListRoots(ctx, nil) if err != nil { return nil, fmt.Errorf("listing roots: %w", err) } @@ -421,32 +639,109 @@ func fileResourceHandler(dir string) ResourceHandler { if err != nil { return nil, err } - data, err := readFileResource(params.URI, dirFilepath, roots) + data, err := readFileResource(req.Params.URI, dirFilepath, roots) if err != nil { return nil, err } // TODO(jba): figure out mime type. Omit for now: Server.readResource will fill it in. return &ReadResourceResult{Contents: []*ResourceContents{ - {URI: params.URI, Blob: data}, + {URI: req.Params.URI, Blob: data}, }}, nil } } +// ResourceUpdated sends a notification to all clients that have subscribed to the +// resource specified in params. This method is the primary way for a +// server author to signal that a resource has changed. +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + subscribedSessions := s.resourceSubscriptions[params.URI] + sessions := slices.Collect(maps.Keys(subscribedSessions)) + s.mu.Unlock() + notifySessions(sessions, notificationResourceUpdated, params) + return nil +} + +func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.resourceSubscriptions[req.Params.URI] == nil { + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) + } + s.resourceSubscriptions[req.Params.URI][req.Session] = true + + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emptyResult, error) { + if s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, req); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if subscribedSessions, ok := s.resourceSubscriptions[req.Params.URI]; ok { + delete(subscribedSessions, req.Session) + if len(subscribedSessions) == 0 { + delete(s.resourceSubscriptions, req.Params.URI) + } + } + + return &emptyResult{}, nil +} + // Run runs the server over the given transport, which must be persistent. // -// Run blocks until the client terminates the connection. +// Run blocks until the client terminates the connection or the provided +// context is cancelled. If the context is cancelled, Run closes the connection. +// +// If tools have been added to the server before this call, then the server will +// advertise the capability for tools, including the ability to send list-changed notifications. +// If no tools have been added, the server will not have the tool capability. +// The same goes for other features like prompts and resources. +// +// Run is a convenience for servers that handle a single session (or one session at a time). +// It need not be called on servers that are used for multiple concurrent connections, +// as with [StreamableHTTPHandler]. func (s *Server) Run(ctx context.Context, t Transport) error { - ss, err := s.Connect(ctx, t) + ss, err := s.Connect(ctx, t, nil) if err != nil { return err } - return ss.Wait() + + ssClosed := make(chan error) + go func() { + ssClosed <- ss.Wait() + }() + + select { + case <-ctx.Done(): + ss.Close() + return ctx.Err() + case err := <-ssClosed: + return err + } } // bind implements the binder[*ServerSession] interface, so that Servers can // be connected using [connect]. -func (s *Server) bind(conn *jsonrpc2.Connection) *ServerSession { - ss := &ServerSession{conn: conn, server: s} +func (s *Server) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *ServerSessionState, onClose func()) *ServerSession { + assert(mcpConn != nil && conn != nil, "nil connection") + ss := &ServerSession{conn: conn, mcpConn: mcpConn, server: s, onClose: onClose} + if state != nil { + ss.state = *state + } s.mu.Lock() s.sessions = append(s.sessions, ss) s.mu.Unlock() @@ -461,6 +756,17 @@ func (s *Server) disconnect(cc *ServerSession) { s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { return cc2 == cc }) + + for _, subscribedSessions := range s.resourceSubscriptions { + delete(subscribedSessions, cc) + } +} + +// ServerSessionOptions configures the server session. +type ServerSessionOptions struct { + State *ServerSessionState + + onClose func() } // Connect connects the MCP server over the given transport and starts handling @@ -469,23 +775,61 @@ func (s *Server) disconnect(cc *ServerSession) { // It returns a connection object that may be used to terminate the connection // (with [Connection.Close]), or await client termination (with // [Connection.Wait]). -func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, error) { - return connect(ctx, t, s) +// +// If opts.State is non-nil, it is the initial state for the server. +func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOptions) (*ServerSession, error) { + var state *ServerSessionState + var onClose func() + if opts != nil { + state = opts.State + onClose = opts.onClose + } + return connect(ctx, t, s, state, onClose) } -func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) { - if s.opts.KeepAlive > 0 { - ss.startKeepalive(s.opts.KeepAlive) +// TODO: (nit) move all ServerSession methods below the ServerSession declaration. +func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if params == nil { + // Since we use nilness to signal 'initialized' state, we must ensure that + // params are non-nil. + params = new(InitializedParams) + } + var wasInit, wasInitd bool + ss.updateState(func(state *ServerSessionState) { + wasInit = state.InitializeParams != nil + wasInitd = state.InitializedParams != nil + if wasInit && !wasInitd { + state.InitializedParams = params + } + }) + + if !wasInit { + return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) + } + if wasInitd { + return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } - return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params) + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } + if h := ss.server.opts.InitializedHandler; h != nil { + h(ctx, serverRequestFor(ss, params)) + } + return nil, nil } -func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSession, params *RootsListChangedParams) (Result, error) { - return callNotificationHandler(ctx, s.opts.RootsListChangedHandler, ss, params) +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsListChangedRequest) (Result, error) { + if h := s.opts.RootsListChangedHandler; h != nil { + h(ctx, req) + } + return nil, nil } -func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, params *ProgressNotificationParams) (Result, error) { - return callNotificationHandler(ctx, ss.server.opts.ProgressNotificationHandler, ss, params) +func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { + if h := ss.server.opts.ProgressNotificationHandler; h != nil { + h(ctx, serverRequestFor(ss, p)) + } + return nil, nil } // NotifyProgress sends a progress notification from the server to the client @@ -493,7 +837,11 @@ func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, pa // This is typically used to report on the status of a long-running request // that was initiated by the client. func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNotificationParams) error { - return handleNotify(ctx, ss, notificationProgress, params) + return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) +} + +func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { + return &ServerRequest[P]{Session: ss, Params: params} } // A ServerSession is a logical connection from a single MCP client. Its @@ -503,41 +851,97 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot // Call [ServerSession.Close] to close the connection, or await client // termination with [ServerSession.Wait]. type ServerSession struct { - server *Server - conn *jsonrpc2.Connection - mcpConn Connection - mu sync.Mutex - logLevel LoggingLevel - initializeParams *InitializeParams - initialized bool - keepaliveCancel context.CancelFunc + onClose func() + + server *Server + conn *jsonrpc2.Connection + mcpConn Connection + keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + + mu sync.Mutex + state ServerSessionState +} + +func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { + ss.mu.Lock() + mut(&ss.state) + copy := ss.state + ss.mu.Unlock() + if c, ok := ss.mcpConn.(serverConnection); ok { + c.sessionUpdated(copy) + } } -func (ss *ServerSession) setConn(c Connection) { - ss.mcpConn = c +// hasInitialized reports whether the server has received the initialized +// notification. +// +// TODO(findleyr): use this to prevent change notifications. +func (ss *ServerSession) hasInitialized() bool { + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.state.InitializedParams != nil +} + +// checkInitialized returns a formatted error if the server has not yet +// received the initialized notification. +func (ss *ServerSession) checkInitialized(method string) error { + if !ss.hasInitialized() { + // TODO(rfindley): enable this check. + // Right now is is flaky, because server tests don't await the initialized notification. + // Perhaps requests should simply block until they have received the initialized notification + + // if strings.HasPrefix(method, "notifications/") { + // return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized) + // } else { + // return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized) + // } + } + return nil } func (ss *ServerSession) ID() string { - if ss.mcpConn == nil { - return "" + if c, ok := ss.mcpConn.(hasSessionID); ok { + return c.SessionID() } - return ss.mcpConn.SessionID() + return "" } // Ping pings the client. func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error { - _, err := handleSend[*emptyResult](ctx, ss, methodPing, params) + _, err := handleSend[*emptyResult](ctx, methodPing, newServerRequest(ss, orZero[Params](params))) return err } // ListRoots lists the client roots. func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) { - return handleSend[*ListRootsResult](ctx, ss, methodListRoots, orZero[Params](params)) + if err := ss.checkInitialized(methodListRoots); err != nil { + return nil, err + } + return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params))) } // CreateMessage sends a sampling request to the client. func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { - return handleSend[*CreateMessageResult](ctx, ss, methodCreateMessage, orZero[Params](params)) + if err := ss.checkInitialized(methodCreateMessage); err != nil { + return nil, err + } + if params == nil { + params = &CreateMessageParams{Messages: []*SamplingMessage{}} + } + if params.Messages == nil { + p2 := *params + p2.Messages = []*SamplingMessage{} // avoid JSON "null" + params = &p2 + } + return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) +} + +// Elicit sends an elicitation request to the client asking for user input. +func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) { + if err := ss.checkInitialized(methodElicit); err != nil { + return nil, err + } + return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) } // Log sends a log message to the client. @@ -545,7 +949,7 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag // is below that of the last SetLevel. func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error { ss.mu.Lock() - logLevel := ss.logLevel + logLevel := ss.state.LogLevel ss.mu.Unlock() if logLevel == "" { // The spec is unclear, but seems to imply that no log messages are sent until the client @@ -556,7 +960,7 @@ func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) if compareLevels(params.Level, logLevel) < 0 { return nil } - return handleNotify(ctx, ss, notificationLoggingMessage, params) + return handleNotify(ctx, notificationLoggingMessage, newServerRequest(ss, orZero[Params](params))) } // AddSendingMiddleware wraps the current sending method handler using the provided @@ -568,7 +972,7 @@ func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) // // Sending middleware is called when a request is sent. It is useful for tasks // such as tracing, metrics, and adding progress tokens. -func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) { +func (s *Server) AddSendingMiddleware(middleware ...Middleware) { s.mu.Lock() defer s.mu.Unlock() addMiddleware(&s.sendingMethodHandler_, middleware) @@ -583,115 +987,107 @@ func (s *Server) AddSendingMiddleware(middleware ...Middleware[*ServerSession]) // // Receiving middleware is called when a request is received. It is useful for tasks // such as authentication, request logging and metrics. -func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]) { +func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { s.mu.Lock() defer s.mu.Unlock() addMiddleware(&s.receivingMethodHandler_, middleware) } // serverMethodInfos maps from the RPC method name to serverMethodInfos. +// +// The 'allowMissingParams' values are extracted from the protocol schema. +// TODO(rfindley): actually load and validate the protocol schema, rather than +// curating these method flags. var serverMethodInfos = map[string]methodInfo{ - methodComplete: newMethodInfo(serverMethod((*Server).complete)), - methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize)), - methodPing: newMethodInfo(sessionMethod((*ServerSession).ping)), - methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts)), - methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt)), - methodListTools: newMethodInfo(serverMethod((*Server).listTools)), - methodCallTool: newMethodInfo(serverMethod((*Server).callTool)), - methodListResources: newMethodInfo(serverMethod((*Server).listResources)), - methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)), - methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), - methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)), - notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)), - notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), - notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)), + methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), + methodInitialize: newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0), + methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), + methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), + methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), + methodListTools: newServerMethodInfo(serverMethod((*Server).listTools), missingParamsOK), + methodCallTool: newServerMethodInfo(serverMethod((*Server).callTool), 0), + methodListResources: newServerMethodInfo(serverMethod((*Server).listResources), missingParamsOK), + methodListResourceTemplates: newServerMethodInfo(serverMethod((*Server).listResourceTemplates), missingParamsOK), + methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), + methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), + methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), + notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), + notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), + notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), } func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } -func (ss *ServerSession) sendingMethodHandler() methodHandler { - ss.server.mu.Lock() - defer ss.server.mu.Unlock() - return ss.server.sendingMethodHandler_ +func (ss *ServerSession) sendingMethodHandler() MethodHandler { + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.sendingMethodHandler_ } -func (ss *ServerSession) receivingMethodHandler() methodHandler { - ss.server.mu.Lock() - defer ss.server.mu.Unlock() - return ss.server.receivingMethodHandler_ +func (ss *ServerSession) receivingMethodHandler() MethodHandler { + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.receivingMethodHandler_ } // getConn implements [session.getConn]. func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. -func (ss *ServerSession) handle(ctx context.Context, req *JSONRPCRequest) (any, error) { +func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() - initialized := ss.initialized + initialized := ss.state.InitializeParams != nil ss.mu.Unlock() + // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." switch req.Method { - case "initialize", "ping": + case methodInitialize, methodPing, notificationInitialized: default: if !initialized { return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } + + // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and + // notifications synchronously, except for 'initialize' which shouldn't be + // asynchronous to other + if req.IsCall() && req.Method != methodInitialize { + jsonrpc2.Async(ctx) + } + // For the streamable transport, we need the request ID to correlate // server->client calls and notifications to the incoming request from which - // they originated. See [idContext] for details. + // they originated. See [idContextKey] for details. ctx = context.WithValue(ctx, idContextKey{}, req.ID) return handleReceive(ctx, ss, req) } -func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { - ss.mu.Lock() - ss.initializeParams = params - ss.mu.Unlock() +func (ss *ServerSession) InitializeParams() *InitializeParams { return ss.state.InitializeParams } - // Mark the connection as initialized when this method exits. - // TODO: Technically, the server should not be considered initialized until it has - // *responded*, but we don't have adequate visibility into the jsonrpc2 - // connection to implement that easily. In any case, once we've initialized - // here, we can handle requests. - defer func() { - ss.mu.Lock() - ss.initialized = true - ss.mu.Unlock() - }() - - version := "2025-03-26" // preferred version - switch v := params.ProtocolVersion; v { - case "2024-11-05", "2025-03-26": - version = v +func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { + if params == nil { + return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = params + }) + s := ss.server return &InitializeResult{ // TODO(rfindley): alter behavior when falling back to an older version: // reject unsupported features. - ProtocolVersion: version, - Capabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Prompts: &promptCapabilities{ - ListChanged: true, - }, - Tools: &toolCapabilities{ - ListChanged: true, - }, - Resources: &resourceCapabilities{ - ListChanged: true, - }, - Logging: &loggingCapabilities{}, - }, - Instructions: ss.server.opts.Instructions, - ServerInfo: &implementation{ - Name: ss.server.name, - Version: ss.server.version, - }, + ProtocolVersion: negotiatedVersion(params.ProtocolVersion), + Capabilities: s.capabilities(), + Instructions: s.opts.Instructions, + ServerInfo: s.impl, }, nil } @@ -699,10 +1095,19 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error return &emptyResult{}, nil } -func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.logLevel = params.Level +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + +func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelParams) (*emptyResult, error) { + ss.updateState(func(state *ServerSessionState) { + state.LogLevel = params.Level + }) return &emptyResult{}, nil } @@ -718,7 +1123,13 @@ func (ss *ServerSession) Close() error { // Close is idempotent and conn.Close() handles concurrent calls correctly ss.keepaliveCancel() } - return ss.conn.Close() + err := ss.conn.Close() + + if ss.onClose != nil { + ss.onClose() + } + + return err } // Wait waits for the connection to be closed by the client. diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 9e982374..db04920b 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -8,67 +8,213 @@ import ( "context" "fmt" "log" + "log/slog" + "sync/atomic" "github.com/modelcontextprotocol/go-sdk/mcp" ) -type SayHiParams struct { - Name string `json:"name"` -} +// !+prompts + +func Example_prompts() { + ctx := context.Background() -func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[SayHiParams]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + params.Arguments.Name}, + promptHandler := func(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Description: "Hi prompt", + Messages: []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}, + }, + }, + }, nil + } + + // Create a server with a single prompt. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + prompt := &mcp.Prompt{ + Name: "greet", + Arguments: []*mcp.PromptArgument{ + { + Name: "name", + Description: "the name of the person to greet", + Required: true, + }, }, - }, nil + } + s.AddPrompt(prompt, promptHandler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { + log.Fatal(err) + } + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + // List the prompts. + for p, err := range cs.Prompts(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(p.Name) + } + + // Get the prompt. + res, err := cs.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: "greet", + Arguments: map[string]string{"name": "Pat"}, + }) + if err != nil { + log.Fatal(err) + } + for _, msg := range res.Messages { + fmt.Println(msg.Role, msg.Content.(*mcp.TextContent).Text) + } + // Output: + // greet + // user Say hi to Pat } -func ExampleServer() { +// !-prompts + +// !+logging + +func Example_logging() { ctx := context.Background() - clientTransport, serverTransport := mcp.NewInMemoryTransports() - server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) + // Create a server. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) - serverSession, err := server.Connect(ctx, serverTransport) + // Create a client that displays log messages. + done := make(chan struct{}) // solely for the example + var nmsgs atomic.Int32 + c := mcp.NewClient( + &mcp.Implementation{Name: "client", Version: "v0.0.1"}, + &mcp.ClientOptions{ + LoggingMessageHandler: func(_ context.Context, r *mcp.LoggingMessageRequest) { + m := r.Params.Data.(map[string]any) + fmt.Println(m["msg"], m["value"]) + if nmsgs.Add(1) == 2 { // number depends on logger calls below + close(done) + } + }, + }) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + ss, err := s.Connect(ctx, t1, nil) if err != nil { log.Fatal(err) } - - client := mcp.NewClient("client", "v0.0.1", nil) - clientSession, err := client.Connect(ctx, clientTransport) + defer ss.Close() + cs, err := c.Connect(ctx, t2, nil) if err != nil { log.Fatal(err) } + defer cs.Close() - res, err := clientSession.CallTool(ctx, &mcp.CallToolParams{ - Name: "greet", - Arguments: map[string]any{"name": "user"}, - }) - if err != nil { + // Set the minimum log level to "info". + if err := cs.SetLoggingLevel(ctx, &mcp.SetLoggingLevelParams{Level: "info"}); err != nil { log.Fatal(err) } - fmt.Println(res.Content[0].(*mcp.TextContent).Text) - clientSession.Close() - serverSession.Wait() + // Get a slog.Logger for the server session. + logger := slog.New(mcp.NewLoggingHandler(ss, nil)) + + // Log some things. + logger.Info("info shows up", "value", 1) + logger.Debug("debug doesn't show up", "value", 2) + logger.Warn("warn shows up", "value", 3) - // Output: Hi user + // Wait for them to arrive on the client. + // In a real application, the log messages would appear asynchronously + // while other work was happening. + <-done + + // Output: + // info shows up 1 + // warn shows up 3 } -// createSessions creates and connects an in-memory client and server session for testing purposes. -func createSessions(ctx context.Context) (*mcp.ClientSession, *mcp.ServerSession, *mcp.Server) { - server := mcp.NewServer("server", "v0.0.1", nil) - client := mcp.NewClient("client", "v0.0.1", nil) - serverTransport, clientTransport := mcp.NewInMemoryTransports() - serverSession, err := server.Connect(ctx, serverTransport) - if err != nil { +// !-logging + +// !+resources +func Example_resources() { + ctx := context.Background() + + resources := map[string]string{ + "file:///a": "a", + "file:///dir/x": "x", + "file:///dir/y": "y", + } + + handler := func(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + uri := req.Params.URI + c, ok := resources[uri] + if !ok { + return nil, mcp.ResourceNotFoundError(uri) + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{URI: uri, Text: c}}, + }, nil + } + + // Create a server with a single resource. + s := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + s.AddResource(&mcp.Resource{URI: "file:///a"}, handler) + s.AddResourceTemplate(&mcp.ResourceTemplate{URITemplate: "file:///dir/{f}"}, handler) + + // Create a client. + c := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + + // Connect the server and client. + t1, t2 := mcp.NewInMemoryTransports() + if _, err := s.Connect(ctx, t1, nil); err != nil { log.Fatal(err) } - clientSession, err := client.Connect(ctx, clientTransport) + cs, err := c.Connect(ctx, t2, nil) if err != nil { log.Fatal(err) } - return clientSession, serverSession, server + defer cs.Close() + + // List resources and resource templates. + for r, err := range cs.Resources(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URI) + } + for r, err := range cs.ResourceTemplates(ctx, nil) { + if err != nil { + log.Fatal(err) + } + fmt.Println(r.URITemplate) + } + + // Read resources. + for _, path := range []string{"a", "dir/x", "b"} { + res, err := cs.ReadResource(ctx, &mcp.ReadResourceParams{URI: "file:///" + path}) + if err != nil { + fmt.Println(err) + } else { + fmt.Println(res.Contents[0].Text) + } + } + // Output: + // file:///a + // file:///dir/{f} + // a + // x + // calling "resources/read": Resource not found } + +// !-resources diff --git a/mcp/server_test.go b/mcp/server_test.go index 19701f39..249ef90b 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -5,11 +5,16 @@ package mcp import ( + "context" + "encoding/json" "log" "slices" + "strings" "testing" + "time" "github.com/google/go-cmp/cmp" + "github.com/google/jsonschema-go/jsonschema" ) type testItem struct { @@ -227,3 +232,359 @@ func TestServerPaginateVariousPageSizes(t *testing.T) { } } } + +func TestServerCapabilities(t *testing.T) { + tool := &Tool{Name: "t", InputSchema: &jsonschema.Schema{Type: "object"}} + testCases := []struct { + name string + configureServer func(s *Server) + serverOpts ServerOptions + wantCapabilities *ServerCapabilities + }{ + { + name: "No capabilities", + configureServer: func(s *Server) {}, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + }, + }, + { + name: "With prompts", + configureServer: func(s *Server) { + s.AddPrompt(&Prompt{Name: "p"}, nil) + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Prompts: &PromptCapabilities{ListChanged: true}, + }, + }, + { + name: "With resources", + configureServer: func(s *Server) { + s.AddResource(&Resource{URI: "file:///r"}, nil) + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Resources: &ResourceCapabilities{ListChanged: true}, + }, + }, + { + name: "With resource templates", + configureServer: func(s *Server) { + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Resources: &ResourceCapabilities{ListChanged: true}, + }, + }, + { + name: "With resource subscriptions", + configureServer: func(s *Server) { + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + }, + serverOpts: ServerOptions{ + SubscribeHandler: func(context.Context, *SubscribeRequest) error { + return nil + }, + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { + return nil + }, + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Resources: &ResourceCapabilities{ListChanged: true, Subscribe: true}, + }, + }, + { + name: "With tools", + configureServer: func(s *Server) { + s.AddTool(tool, nil) + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + }, + { + name: "With completions", + configureServer: func(s *Server) {}, + serverOpts: ServerOptions{ + CompletionHandler: func(context.Context, *CompleteRequest) (*CompleteResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Completions: &CompletionCapabilities{}, + }, + }, + { + name: "With all capabilities", + configureServer: func(s *Server) { + s.AddPrompt(&Prompt{Name: "p"}, nil) + s.AddResource(&Resource{URI: "file:///r"}, nil) + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + s.AddTool(tool, nil) + }, + serverOpts: ServerOptions{ + SubscribeHandler: func(context.Context, *SubscribeRequest) error { + return nil + }, + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { + return nil + }, + CompletionHandler: func(context.Context, *CompleteRequest) (*CompleteResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Prompts: &PromptCapabilities{ListChanged: true}, + Resources: &ResourceCapabilities{ListChanged: true, Subscribe: true}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + }, + { + name: "With initial capabilities", + configureServer: func(s *Server) {}, + serverOpts: ServerOptions{ + HasPrompts: true, + HasResources: true, + HasTools: true, + }, + wantCapabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Prompts: &PromptCapabilities{ListChanged: true}, + Resources: &ResourceCapabilities{ListChanged: true}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := NewServer(testImpl, &tc.serverOpts) + tc.configureServer(server) + gotCapabilities := server.capabilities() + if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { + t.Errorf("capabilities() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestServerAddResourceTemplate(t *testing.T) { + tests := []struct { + name string + template string + expectPanic bool + }{ + {"ValidFileTemplate", "file:///{a}/{b}", false}, + {"ValidCustomScheme", "myproto:///{a}", false}, + {"EmptyVariable", "file:///{}/{b}", true}, + {"UnclosedVariable", "file:///{a", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rt := ResourceTemplate{URITemplate: tt.template} + + defer func() { + if r := recover(); r != nil { + if !tt.expectPanic { + t.Errorf("%s: unexpected panic: %v", tt.name, r) + } + } else { + if tt.expectPanic { + t.Errorf("%s: expected panic but did not panic", tt.name) + } + } + }() + + s := NewServer(testImpl, nil) + s.AddResourceTemplate(&rt, nil) + }) + } +} + +// TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once, +// ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism. +func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { + // Set KeepAlive to a long duration to ensure the keepalive + // goroutine stays alive for the duration of the test without actually sending + // ping requests, since we don't have a real client connection established. + server := NewServer(testImpl, &ServerOptions{KeepAlive: 5 * time.Second}) + ss := &ServerSession{server: server} + + // 1. Initialize the session. + _, err := ss.initialize(context.Background(), &InitializeParams{}) + if err != nil { + t.Fatalf("ServerSession initialize failed: %v", err) + } + + // 2. Call 'initialized' for the first time. This should start the keepalive mechanism. + _, err = ss.initialized(context.Background(), &InitializedParams{}) + if err != nil { + t.Fatalf("First initialized call failed: %v", err) + } + if ss.keepaliveCancel == nil { + t.Fatalf("expected ServerSession.keepaliveCancel to be set after the first call of initialized") + } + + // Save the cancel function and use defer to ensure resources are cleaned up. + firstCancel := ss.keepaliveCancel + defer firstCancel() + + // 3. Manually set the field to nil. + // Do this to facilitate the test's core assertion. The goal is to verify that + // 'ss.keepaliveCancel' is not assigned a second time. By setting it to nil, + // we can easily check after the next call if a new keepalive goroutine was started. + ss.keepaliveCancel = nil + + // 4. Call 'initialized' for the second time. This should return an error. + _, err = ss.initialized(context.Background(), &InitializedParams{}) + if err == nil { + t.Fatalf("Expected 'duplicate initialized received' error on second call, got nil") + } + + // 5. Re-check the field to ensure it remains nil. + // Since 'initialized' correctly returned an error and did not call + // 'startKeepalive', the field should remain unchanged. + if ss.keepaliveCancel != nil { + t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized") + } +} + +// panicks reports whether f() panics. +func panics(f func()) (b bool) { + defer func() { + b = recover() != nil + }() + f() + return false +} + +func TestAddTool(t *testing.T) { + // AddTool should panic if In or Out are not JSON objects. + s := NewServer(testImpl, nil) + if !panics(func() { + AddTool(s, &Tool{Name: "T1"}, func(context.Context, *CallToolRequest, string) (*CallToolResult, any, error) { return nil, nil, nil }) + }) { + t.Error("bad In: expected panic") + } + if panics(func() { + AddTool(s, &Tool{Name: "T2"}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return nil, nil, nil + }) + }) { + t.Error("good In: expected no panic") + } + if !panics(func() { + AddTool(s, &Tool{Name: "T2"}, func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, int, error) { + return nil, 0, nil + }) + }) { + t.Error("bad Out: expected panic") + } +} + +type schema = jsonschema.Schema + +func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErrContaining string) { + t.Helper() + th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { + return nil, out, nil + } + gott, goth, err := toolForErr(tool, th) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(wantIn, gott.InputSchema); diff != "" { + t.Errorf("input: mismatch (-want, +got):\n%s", diff) + } + if diff := cmp.Diff(wantOut, gott.OutputSchema); diff != "" { + t.Errorf("output: mismatch (-want, +got):\n%s", diff) + } + ctr := &CallToolRequest{ + Params: &CallToolParamsRaw{ + Arguments: json.RawMessage(in), + }, + } + result, err := goth(context.Background(), ctr) + if wantErrContaining != "" { + if err == nil { + t.Errorf("got nil error, want error containing %q", wantErrContaining) + } else { + if !strings.Contains(err.Error(), wantErrContaining) { + t.Errorf("got error %q, want containing %q", err, wantErrContaining) + } + } + } else if err != nil { + t.Errorf("got error %v, want no error", err) + } + + if gott.OutputSchema != nil && err == nil && !result.IsError { + // Check that structured content matches exactly. + unstructured := result.Content[0].(*TextContent).Text + structured := string(result.StructuredContent.(json.RawMessage)) + if diff := cmp.Diff(unstructured, structured); diff != "" { + t.Errorf("Unstructured content does not match structured content exactly (-unstructured +structured):\n%s", diff) + } + } +} + +// TODO: move this to tool_test.go +func TestToolForSchemas(t *testing.T) { + // Validate that toolForErr handles schemas properly. + type in struct { + P int `json:"p,omitempty"` + } + type out struct { + B bool `json:"b,omitempty"` + } + + var ( + falseSchema = &schema{Not: &schema{}} + inSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "integer"}}} + inSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "string"}}} + outSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "boolean"}}} + outSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "integer"}}} + ) + + // Infer both schemas. + testToolForSchema[in](t, &Tool{}, `{"p":3}`, out{true}, inSchema, outSchema, "") + // Validate the input schema: expect an error if it's wrong. + // We can't test that the output schema is validated, because it's typed. + testToolForSchema[in](t, &Tool{}, `{"p":"x"}`, out{true}, inSchema, outSchema, `want "integer"`) + // Ignore type any for output. + testToolForSchema[in, any](t, &Tool{}, `{"p":3}`, 0, inSchema, nil, "") + // Input is still validated. + testToolForSchema[in, any](t, &Tool{}, `{"p":"x"}`, 0, inSchema, nil, `want "integer"`) + // Tool sets input schema: that is what's used. + testToolForSchema[in, any](t, &Tool{InputSchema: inSchema2}, `{"p":3}`, 0, inSchema2, nil, `want "string"`) + // Tool sets output schema: that is what's used, and validation happens. + testToolForSchema[in, any](t, &Tool{OutputSchema: outSchema2}, `{"p":3}`, out{true}, + inSchema, outSchema2, `want "integer"`) + + // Check a slightly more complicated case. + type weatherOutput struct { + Summary string + AsOf time.Time + Source string + } + testToolForSchema[any](t, &Tool{}, `{}`, weatherOutput{}, + &schema{Type: "object"}, + &schema{ + Type: "object", + Required: []string{"Summary", "AsOf", "Source"}, + AdditionalProperties: falseSchema, + Properties: map[string]*schema{ + "Summary": {Type: "string"}, + "AsOf": {Type: "string"}, + "Source": {Type: "string"}, + }, + }, + "") +} diff --git a/mcp/session.go b/mcp/session.go new file mode 100644 index 00000000..dcf9888c --- /dev/null +++ b/mcp/session.go @@ -0,0 +1,29 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +// hasSessionID is the interface which, if implemented by connections, informs +// the session about their session ID. +// +// TODO(rfindley): remove SessionID methods from connections, when it doesn't +// make sense. Or remove it from the Sessions entirely: why does it even need +// to be exposed? +type hasSessionID interface { + SessionID() string +} + +// ServerSessionState is the state of a session. +type ServerSessionState struct { + // InitializeParams are the parameters from 'initialize'. + InitializeParams *InitializeParams `json:"initializeParams"` + + // InitializedParams are the parameters from 'notifications/initialized'. + InitializedParams *InitializedParams `json:"initializedParams"` + + // LogLevel is the logging level for the session. + LogLevel LoggingLevel `json:"logLevel"` + + // TODO: resource subscriptions +} diff --git a/mcp/shared.go b/mcp/shared.go index db871ca8..69b8836a 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -4,8 +4,10 @@ // This file contains code shared between client and server, including // method handler and middleware definitions. -// TODO: much of this is here so that we can factor out commonalities using -// generics. Perhaps it can be simplified with reflection. +// +// Much of this is here so that we can factor out commonalities using +// generics. If this becomes unwieldy, it can perhaps be simplified with +// reflection. package mcp @@ -14,68 +16,96 @@ import ( "encoding/json" "fmt" "log" + "net/http" "reflect" "slices" "strings" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) +const ( + // latestProtocolVersion is the latest protocol version that this version of + // the SDK supports. + // + // It is the version that the client sends in the initialization request, and + // the default version used by the server. + latestProtocolVersion = protocolVersion20250618 + protocolVersion20250618 = "2025-06-18" + protocolVersion20250326 = "2025-03-26" + protocolVersion20241105 = "2024-11-05" +) + +var supportedProtocolVersions = []string{ + protocolVersion20250618, + protocolVersion20250326, + protocolVersion20241105, +} + +// negotiatedVersion returns the effective protocol version to use, given a +// client version. +func negotiatedVersion(clientVersion string) string { + // In general, prefer to use the clientVersion, but if we don't support the + // client's version, use the latest version. + // + // This handles the case where a new spec version is released, and the SDK + // does not support it yet. + if !slices.Contains(supportedProtocolVersions, clientVersion) { + return latestProtocolVersion + } + return clientVersion +} + // A MethodHandler handles MCP messages. // For methods, exactly one of the return values must be nil. // For notifications, both must be nil. -type MethodHandler[S Session] func(ctx context.Context, _ S, method string, params Params) (result Result, err error) - -// A methodHandler is a MethodHandler[Session] for some session. -// We need to give up type safety here, or we will end up with a type cycle somewhere -// else. For example, if Session.methodHandler returned a MethodHandler[Session], -// the compiler would complain. -type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerSession] +type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) -// A Session is either a ClientSession or a ServerSession. +// A Session is either a [ClientSession] or a [ServerSession]. type Session interface { - *ClientSession | *ServerSession // ID returns the session ID, or the empty string if there is none. ID() string sendingMethodInfos() map[string]methodInfo receivingMethodInfos() map[string]methodInfo - sendingMethodHandler() methodHandler - receivingMethodHandler() methodHandler + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler getConn() *jsonrpc2.Connection } -// Middleware is a function from MethodHandlers to MethodHandlers. -type Middleware[S Session] func(MethodHandler[S]) MethodHandler[S] +// Middleware is a function from [MethodHandler] to [MethodHandler]. +type Middleware func(MethodHandler) MethodHandler // addMiddleware wraps the handler in the middleware functions. -func addMiddleware[S Session](handlerp *MethodHandler[S], middleware []Middleware[S]) { +func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { for _, m := range slices.Backward(middleware) { *handlerp = m(*handlerp) } } -func defaultSendingMethodHandler[S Session](ctx context.Context, session S, method string, params Params) (Result, error) { - info, ok := session.sendingMethodInfos()[method] +func defaultSendingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().sendingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } // Notifications don't have results. if strings.HasPrefix(method, "notifications/") { - return nil, session.getConn().Notify(ctx, method, params) + return nil, req.GetSession().getConn().Notify(ctx, method, req.GetParams()) } // Create the result to unmarshal into. // The concrete type of the result is the return type of the receiving function. res := info.newResult() - if err := call(ctx, session.getConn(), method, params, res); err != nil { + if err := call(ctx, req.GetSession().getConn(), method, req.GetParams(), res); err != nil { return nil, err } return res, nil } -// Helper methods to avoid typed nil. +// Helper method to avoid typed nil. func orZero[T any, P *U, U any](p P) T { if p == nil { var zero T @@ -84,16 +114,16 @@ func orZero[T any, P *U, U any](p P) T { return any(p).(T) } -func handleNotify[S Session](ctx context.Context, session S, method string, params Params) error { - mh := session.sendingMethodHandler().(MethodHandler[S]) - _, err := mh(ctx, session, method, params) +func handleNotify(ctx context.Context, method string, req Request) error { + mh := req.GetSession().sendingMethodHandler() + _, err := mh(ctx, method, req) return err } -func handleSend[R Result, S Session](ctx context.Context, s S, method string, params Params) (R, error) { - mh := s.sendingMethodHandler().(MethodHandler[S]) +func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { + mh := req.GetSession().sendingMethodHandler() // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. - res, err := mh(ctx, s, method, params) + res, err := mh(ctx, method, req) if err != nil { var z R return z, err @@ -102,42 +132,78 @@ func handleSend[R Result, S Session](ctx context.Context, s S, method string, pa } // defaultReceivingMethodHandler is the initial MethodHandler for servers and clients, before being wrapped by middleware. -func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, method string, params Params) (Result, error) { - info, ok := session.receivingMethodInfos()[method] +func defaultReceivingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { + info, ok := req.GetSession().receivingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } - return info.handleMethod.(MethodHandler[S])(ctx, session, method, params) + return info.handleMethod(ctx, method, req) } -func handleReceive[S Session](ctx context.Context, session S, req *JSONRPCRequest) (Result, error) { - info, ok := session.receivingMethodInfos()[req.Method] - if !ok { - return nil, jsonrpc2.ErrNotHandled +func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { + info, err := checkRequest(jreq, session.receivingMethodInfos()) + if err != nil { + return nil, err } - params, err := info.unmarshalParams(req.Params) + params, err := info.unmarshalParams(jreq.Params) if err != nil { - return nil, fmt.Errorf("handleRequest %q: %w", req.Method, err) + return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) } - mh := session.receivingMethodHandler().(MethodHandler[S]) + mh := session.receivingMethodHandler() + re, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, re) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. - res, err := mh(ctx, session, req.Method, params) + res, err := mh(ctx, jreq.Method, req) if err != nil { return nil, err } return res, nil } +// checkRequest checks the given request against the provided method info, to +// ensure it is a valid MCP request. +// +// If valid, the relevant method info is returned. Otherwise, a non-nil error +// is returned describing why the request is invalid. +// +// This is extracted from request handling so that it can be called in the +// transport layer to preemptively reject bad requests. +func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) { + info, ok := infos[req.Method] + if !ok { + return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) + } + if info.flags¬ification != 0 && req.IsCall() { + return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + if info.flags¬ification == 0 && !req.IsCall() { + return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) + } + // missingParamsOK is checked here to catch the common case where "params" is + // missing entirely. + // + // However, it's checked again after unmarshalling to catch the rare but + // possible case where "params" is JSON null (see https://go.dev/issue/33835). + if info.flags&missingParamsOK == 0 && len(req.Params) == 0 { + return methodInfo{}, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return info, nil +} + // methodInfo is information about sending and receiving a method. type methodInfo struct { + // flags is a collection of flags controlling how the JSONRPC method is + // handled. See individual flag values for documentation. + flags methodFlags // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) + newRequest func(Session, Params, *RequestExtra) Request // Run the code when a call to the method is received. // Used on the receive side. - handleMethod methodHandler + handleMethod MethodHandler // Create a pointer to a Result struct. // Used on the send side. newResult func() Result @@ -150,16 +216,60 @@ type methodInfo struct { // - R: results // A typedMethodHandler is like a MethodHandler, but with type information. -type typedMethodHandler[S Session, P Params, R Result] func(context.Context, S, P) (R, error) +type ( + typedClientMethodHandler[P Params, R Result] func(context.Context, *ClientRequest[P]) (R, error) + typedServerMethodHandler[P Params, R Result] func(context.Context, *ServerRequest[P]) (R, error) +) type paramsPtr[T any] interface { *T Params } +type methodFlags int + +const ( + notification methodFlags = 1 << iota // method is a notification, not request + missingParamsOK // params may be missing or null +) + +func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { + r := &ClientRequest[P]{Session: s.(*ClientSession)} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ClientRequest[P])) + }) + return mi +} + +func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { + mi := newMethodInfo[P, R](flags) + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} + if p != nil { + r.Params = p.(P) + } + return r + } + mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { + return d(ctx, req.(*ServerRequest[P])) + }) + return mi +} + // newMethodInfo creates a methodInfo from a typedMethodHandler. -func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R]) methodInfo { +// +// If isRequest is set, the method is treated as a request rather than a +// notification. +func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInfo { return methodInfo{ + flags: flags, unmarshalParams: func(m json.RawMessage) (Params, error) { var p P if m != nil { @@ -167,14 +277,18 @@ func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHand return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) } } + // We must check missingParamsOK here, in addition to checkRequest, to + // catch the edge cases where "params" is set to JSON null. + // See also https://go.dev/issue/33835. + // + // We need to ensure that p is non-null to guard against crashes, as our + // internal code or externally provided handlers may assume that params + // is non-null. + if flags&missingParamsOK == 0 && p == nil { + return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } return orZero[Params](p), nil }, - handleMethod: MethodHandler[S](func(ctx context.Context, session S, _ string, params Params) (Result, error) { - if params == nil { - return d(ctx, session, nil) - } - return d(ctx, session, params.(P)) - }), // newResult is used on the send side, to construct the value to unmarshal the result into. // R is a pointer to a result struct. There is no way to "unpointer" it without reflection. // TODO(jba): explore generic approaches to this, perhaps by treating R in @@ -185,60 +299,79 @@ func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHand // serverMethod is glue for creating a typedMethodHandler from a method on Server. func serverMethod[P Params, R Result]( - f func(*Server, context.Context, *ServerSession, P) (R, error), -) typedMethodHandler[*ServerSession, P, R] { - return func(ctx context.Context, ss *ServerSession, p P) (R, error) { - return f(ss.server, ctx, ss, p) + f func(*Server, context.Context, *ServerRequest[P]) (R, error), +) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.Session.server, ctx, req) } } -// clientMethod is glue for creating a typedMethodHandler from a method on Server. +// clientMethod is glue for creating a typedMethodHandler from a method on Client. func clientMethod[P Params, R Result]( - f func(*Client, context.Context, *ClientSession, P) (R, error), -) typedMethodHandler[*ClientSession, P, R] { - return func(ctx context.Context, cs *ClientSession, p P) (R, error) { - return f(cs.client, ctx, cs, p) + f func(*Client, context.Context, *ClientRequest[P]) (R, error), +) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.Session.client, ctx, req) + } +} + +// serverSessionMethod is glue for creating a typedServerMethodHandler from a method on ServerSession. +func serverSessionMethod[P Params, R Result](f func(*ServerSession, context.Context, P) (R, error)) typedServerMethodHandler[P, R] { + return func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return f(req.GetSession().(*ServerSession), ctx, req.Params) } } -// sessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. -func sessionMethod[S Session, P Params, R Result](f func(S, context.Context, P) (R, error)) typedMethodHandler[S, P, R] { - return func(ctx context.Context, sess S, p P) (R, error) { - return f(sess, ctx, p) +// clientSessionMethod is glue for creating a typedMethodHandler from a method on ServerSession. +func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Context, P) (R, error)) typedClientMethodHandler[P, R] { + return func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return f(req.GetSession().(*ClientSession), ctx, req.Params) } } // Error codes const ( + // TODO: should these be unexported? + CodeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. CodeUnsupportedMethod = -31001 + // The error code for invalid parameters + CodeInvalidParams = -32602 ) -func callNotificationHandler[S Session, P any](ctx context.Context, h func(context.Context, S, *P), sess S, params *P) (Result, error) { - if h != nil { - h(ctx, sess, params) - } - return nil, nil -} - // notifySessions calls Notify on all the sessions. // Should be called on a copy of the peer sessions. -func notifySessions[S Session](sessions []S, method string, params Params) { +func notifySessions[S Session, P Params](sessions []S, method string, params P) { if sessions == nil { return } - // TODO: make this timeout configurable, or call Notify asynchronously. + // TODO: make this timeout configurable, or call handleNotify asynchronously. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + + // TODO: there's a potential spec violation here, when the feature list + // changes before the session (client or server) is initialized. for _, s := range sessions { - if err := handleNotify(ctx, s, method, params); err != nil { + req := newRequest(s, params) + if err := handleNotify(ctx, method, req); err != nil { // TODO(jba): surface this error better log.Printf("calling %s: %v", method, err) } } } +func newRequest[S Session, P Params](s S, p P) Request { + switch s := any(s).(type) { + case *ClientSession: + return &ClientRequest[P]{Session: s, Params: p} + case *ServerSession: + return &ServerRequest[P]{Session: s, Params: p} + default: + panic("bad session") + } +} + // Meta is additional metadata for requests, responses and other types. type Meta map[string]any @@ -268,8 +401,61 @@ func setProgressToken(p Params, pt any) { m[progressTokenKey] = pt } +// A Request is a method request with parameters and additional information, such as the session. +// Request is implemented by [*ClientRequest] and [*ServerRequest]. +type Request interface { + isRequest() + GetSession() Session + GetParams() Params + // GetExtra returns the Extra field for ServerRequests, and nil for ClientRequests. + GetExtra() *RequestExtra +} + +// A ClientRequest is a request to a client. +type ClientRequest[P Params] struct { + Session *ClientSession + Params P +} + +// A ServerRequest is a request to a server. +type ServerRequest[P Params] struct { + Session *ServerSession + Params P + Extra *RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. +type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any + Header http.Header // header from HTTP request, if any +} + +func (*ClientRequest[P]) isRequest() {} +func (*ServerRequest[P]) isRequest() {} + +func (r *ClientRequest[P]) GetSession() Session { return r.Session } +func (r *ServerRequest[P]) GetSession() Session { return r.Session } + +func (r *ClientRequest[P]) GetParams() Params { return r.Params } +func (r *ServerRequest[P]) GetParams() Params { return r.Params } + +func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } +func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } + +func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { + return &ServerRequest[P]{Session: s, Params: p} +} + +func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] { + return &ClientRequest[P]{Session: s, Params: p} +} + // Params is a parameter (input) type for an MCP call or notification. type Params interface { + // isParams discourages implementation of Params outside of this package. + isParams() + // GetMeta returns metadata from a value. GetMeta() map[string]any // SetMeta sets the metadata on a value. @@ -291,6 +477,9 @@ type RequestParams interface { // Result is a result of an MCP call. type Result interface { + // isResult discourages implementation of Result outside of this package. + isResult() + // GetMeta returns metadata from a value. GetMeta() map[string]any // SetMeta sets the metadata on a value. @@ -301,6 +490,7 @@ type Result interface { // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} +func (*emptyResult) isResult() {} func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") } diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 5a1d5d02..23818f87 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,86 +4,222 @@ package mcp -import ( - "context" - "encoding/json" - "strings" - "testing" -) - -// TODO(jba): this shouldn't be in this file, but tool_test.go doesn't have access to unexported symbols. -func TestNewServerToolValidate(t *testing.T) { - // Check that the tool returned from NewServerTool properly validates its input schema. - - type req struct { - I int - B bool - S string `json:",omitempty"` - P *int `json:",omitempty"` - } - - dummyHandler := func(context.Context, *ServerSession, *CallToolParamsFor[req]) (*CallToolResultFor[any], error) { - return nil, nil - } - - tool := NewServerTool("test", "test", dummyHandler) - // Need to add the tool to a server to get resolved schemas. - // s := NewServer("", "", nil) - - for _, tt := range []struct { - desc string - args map[string]any - want string // error should contain this string; empty for success - }{ - { - "both required", - map[string]any{"I": 1, "B": true}, - "", - }, - { - "optional", - map[string]any{"I": 1, "B": true, "S": "foo"}, - "", - }, - { - "wrong type", - map[string]any{"I": 1.5, "B": true}, - "cannot unmarshal", - }, - { - "extra property", - map[string]any{"I": 1, "B": true, "C": 2}, - "unknown field", - }, - { - "value for pointer", - map[string]any{"I": 1, "B": true, "P": 3}, - "", - }, - { - "null for pointer", - map[string]any{"I": 1, "B": true, "P": nil}, - "", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - raw, err := json.Marshal(tt.args) - if err != nil { - t.Fatal(err) - } - _, err = tool.rawHandler(context.Background(), nil, - &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}) - if err == nil && tt.want != "" { - t.Error("got success, wanted failure") - } - if err != nil { - if tt.want == "" { - t.Fatalf("failed with:\n%s\nwanted success", err) - } - if !strings.Contains(err.Error(), tt.want) { - t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) - } - } - }) - } -} +// TODO(v0.3.0): rewrite this test. +// func TestToolValidate(t *testing.T) { +// // Check that the tool returned from NewServerTool properly validates its input schema. + +// type req struct { +// I int +// B bool +// S string `json:",omitempty"` +// P *int `json:",omitempty"` +// } + +// dummyHandler := func(context.Context, *CallToolRequest, req) (*CallToolResultFor[any], error) { +// return nil, nil +// } + +// st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) +// if err != nil { +// t.Fatal(err) +// } + +// for _, tt := range []struct { +// desc string +// args map[string]any +// want string // error should contain this string; empty for success +// }{ +// { +// "both required", +// map[string]any{"I": 1, "B": true}, +// "", +// }, +// { +// "optional", +// map[string]any{"I": 1, "B": true, "S": "foo"}, +// "", +// }, +// { +// "wrong type", +// map[string]any{"I": 1.5, "B": true}, +// "cannot unmarshal", +// }, +// { +// "extra property", +// map[string]any{"I": 1, "B": true, "C": 2}, +// "unknown field", +// }, +// { +// "value for pointer", +// map[string]any{"I": 1, "B": true, "P": 3}, +// "", +// }, +// { +// "null for pointer", +// map[string]any{"I": 1, "B": true, "P": nil}, +// "", +// }, +// } { +// t.Run(tt.desc, func(t *testing.T) { +// raw, err := json.Marshal(tt.args) +// if err != nil { +// t.Fatal(err) +// } +// _, err = st.handler(context.Background(), &ServerRequest[*CallToolParamsFor[json.RawMessage]]{ +// Params: &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}, +// }) +// if err == nil && tt.want != "" { +// t.Error("got success, wanted failure") +// } +// if err != nil { +// if tt.want == "" { +// t.Fatalf("failed with:\n%s\nwanted success", err) +// } +// if !strings.Contains(err.Error(), tt.want) { +// t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) +// } +// } +// }) +// } +// } + +// TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams. +// This addresses a vulnerability where missing or null parameters could crash the server. +// func TestNilParamsHandling(t *testing.T) { +// // Define test types for clarity +// type TestArgs struct { +// Name string `json:"name"` +// Value int `json:"value"` +// } + +// // Simple test handler +// testHandler := func(ctx context.Context, req *ServerRequest[**GetPromptParams]) (*GetPromptResult, error) { +// result := "processed: " + req.Params.Arguments.Name +// return &CallToolResultFor[string]{StructuredContent: result}, nil +// } + +// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) + +// // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully +// mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { +// t.Helper() + +// defer func() { +// if r := recover(); r != nil { +// t.Fatalf("unmarshalParams panicked: %v", r) +// } +// }() + +// params, err := methodInfo.unmarshalParams(rawMsg) +// if err != nil { +// t.Fatalf("unmarshalParams failed: %v", err) +// } + +// if expectNil { +// if params != nil { +// t.Fatalf("Expected nil params, got %v", params) +// } +// return params +// } + +// if params == nil { +// t.Fatal("unmarshalParams returned unexpected nil") +// } + +// // Verify the result can be used safely +// typedParams := params.(TestParams) +// _ = typedParams.Name +// _ = typedParams.Arguments.Name +// _ = typedParams.Arguments.Value + +// return params +// } + +// // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil +// t.Run("missing_params", func(t *testing.T) { +// mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag +// }) + +// t.Run("explicit_null", func(t *testing.T) { +// mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag +// }) + +// t.Run("empty_object", func(t *testing.T) { +// mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params +// }) + +// t.Run("valid_params", func(t *testing.T) { +// rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) +// params := mustNotPanic(t, rawMsg, false) + +// // For valid params, also verify the values are parsed correctly +// typedParams := params.(TestParams) +// if typedParams.Name != "test" { +// t.Errorf("Expected name 'test', got %q", typedParams.Name) +// } +// if typedParams.Arguments.Name != "hello" { +// t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) +// } +// if typedParams.Arguments.Value != 42 { +// t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) +// } +// }) +// } + +// TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix +// func TestNilParamsEdgeCases(t *testing.T) { +// type TestArgs struct { +// Name string `json:"name"` +// Value int `json:"value"` +// } +// type TestParams = *CallToolParamsFor[TestArgs] + +// testHandler := func(context.Context, *ServerRequest[TestParams]) (*CallToolResultFor[string], error) { +// return &CallToolResultFor[string]{StructuredContent: "test"}, nil +// } + +// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) + +// // These should fail normally, not be treated as nil params +// invalidCases := []json.RawMessage{ +// json.RawMessage(""), // empty string - should error +// json.RawMessage("[]"), // array - should error +// json.RawMessage(`"null"`), // string "null" - should error +// json.RawMessage("0"), // number - should error +// json.RawMessage("false"), // boolean - should error +// } + +// for i, rawMsg := range invalidCases { +// t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { +// params, err := methodInfo.unmarshalParams(rawMsg) +// if err == nil && params == nil { +// t.Error("Should not return nil params without error") +// } +// }) +// } + +// // Test that methods without missingParamsOK flag properly reject nil params +// t.Run("reject_when_params_required", func(t *testing.T) { +// methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag + +// testCases := []struct { +// name string +// params json.RawMessage +// }{ +// {"nil_params", nil}, +// {"null_params", json.RawMessage(`null`)}, +// } + +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// _, err := methodInfoStrict.unmarshalParams(tc.params) +// if err == nil { +// t.Error("Expected error for required params, got nil") +// } +// if !strings.Contains(err.Error(), "missing required \"params\"") { +// t.Errorf("Expected 'missing required params' error, got: %v", err) +// } +// }) +// } +// }) +// } diff --git a/mcp/sse.go b/mcp/sse.go index 0a1f9b1b..7f644918 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -5,19 +5,16 @@ package mcp import ( - "bufio" "bytes" "context" - "errors" "fmt" "io" - "iter" "net/http" "net/url" - "strings" "sync" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // This file implements support for SSE (HTTP with server-sent events) @@ -40,111 +37,100 @@ import ( // - Read reads off a message queue that is pushed to via POST requests. // - Close causes the hanging GET to exit. -// An event is a server-sent event. -type event struct { - name string - id string - data []byte -} - -func (e event) empty() bool { - return e.name == "" && e.id == "" && len(e.data) == 0 -} - -// writeEvent writes the event to w, and flushes. -func writeEvent(w io.Writer, evt event) (int, error) { - var b bytes.Buffer - if evt.name != "" { - fmt.Fprintf(&b, "event: %s\n", evt.name) - } - if evt.id != "" { - fmt.Fprintf(&b, "id: %s\n", evt.id) - } - fmt.Fprintf(&b, "data: %s\n\n", string(evt.data)) - n, err := w.Write(b.Bytes()) - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return n, err -} - // SSEHandler is an http.Handler that serves SSE-based MCP sessions as defined by // the [2024-11-05 version] of the MCP spec. // // [2024-11-05 version]: https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEHandler struct { getServer func(request *http.Request) *Server + opts SSEOptions onConnection func(*ServerSession) // for testing; must not block mu sync.Mutex sessions map[string]*SSEServerTransport } +// SSEOptions specifies options for an [SSEHandler]. +// for now, it is empty, but may be extended in future. +// https://github.com/modelcontextprotocol/go-sdk/issues/507 +type SSEOptions struct{} + // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP // sessions created via incoming HTTP requests. // // Sessions are created when the client issues a GET request to the server, // which must accept text/event-stream responses (server-sent events). // For each such request, a new [SSEServerTransport] is created with a distinct -// messages endpoint, and connected to the server returned by getServer. It is -// up to the user whether getServer returns a distinct [Server] for each new -// request, or reuses an existing server. -// +// messages endpoint, and connected to the server returned by getServer. // The SSEHandler also handles requests to the message endpoints, by // delegating them to the relevant server transport. // -// TODO(rfindley): add options. -func NewSSEHandler(getServer func(request *http.Request) *Server) *SSEHandler { - return &SSEHandler{ +// The getServer function may return a distinct [Server] for each new +// request, or reuse an existing server. If it returns nil, the handler +// will return a 400 Bad Request. +func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptions) *SSEHandler { + s := &SSEHandler{ getServer: getServer, sessions: make(map[string]*SSEServerTransport), } + + if opts != nil { + s.opts = *opts + } + + return s } // A SSEServerTransport is a logical SSE session created through a hanging GET // request. // +// Use [SSEServerTransport.Connect] to initiate the flow of messages. +// // When connected, it returns the following [Connection] implementation: // - Writes are SSE 'message' events to the GET response. // - Reads are received from POSTs to the session endpoint, via // [SSEServerTransport.ServeHTTP]. // - Close terminates the hanging GET. -type SSEServerTransport struct { - endpoint string - incoming chan JSONRPCMessage // queue of incoming messages; never closed - - // We must guard both pushes to the incoming queue and writes to the response - // writer, because incoming POST requests are arbitrarily concurrent and we - // need to ensure we don't write push to the queue, or write to the - // ResponseWriter, after the session GET request exits. - mu sync.Mutex - w http.ResponseWriter // the hanging response body - closed bool // set when the stream is closed - done chan struct{} // closed when the connection is closed -} - -// NewSSEServerTransport creates a new SSE transport for the given messages -// endpoint, and hanging GET response. -// -// Use [SSEServerTransport.Connect] to initiate the flow of messages. // // The transport is itself an [http.Handler]. It is the caller's responsibility // to ensure that the resulting transport serves HTTP requests on the given // session endpoint. // +// Each SSEServerTransport may be connected (via [Server.Connect]) at most +// once, since [SSEServerTransport.ServeHTTP] serves messages to the connected +// session. +// // Most callers should instead use an [SSEHandler], which transparently handles // the delegation to SSEServerTransports. -func NewSSEServerTransport(endpoint string, w http.ResponseWriter) *SSEServerTransport { - return &SSEServerTransport{ - endpoint: endpoint, - w: w, - incoming: make(chan JSONRPCMessage, 100), - done: make(chan struct{}), - } +type SSEServerTransport struct { + // Endpoint is the endpoint for this session, where the client can POST + // messages. + Endpoint string + + // Response is the hanging response body to the incoming GET request. + Response http.ResponseWriter + + // incoming is the queue of incoming messages. + // It is never closed, and by convention, incoming is non-nil if and only if + // the transport is connected. + incoming chan jsonrpc.Message + + // We must guard both pushes to the incoming queue and writes to the response + // writer, because incoming POST requests are arbitrarily concurrent and we + // need to ensure we don't write push to the queue, or write to the + // ResponseWriter, after the session GET request exits. + mu sync.Mutex // also guards writes to Response + closed bool // set when the stream is closed + done chan struct{} // closed when the connection is closed } // ServeHTTP handles POST requests to the transport endpoint. func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.incoming == nil { + http.Error(w, "session not connected", http.StatusInternalServerError) + return + } + // Read and parse the message. data, err := io.ReadAll(req.Body) if err != nil { @@ -159,6 +145,12 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) http.Error(w, "failed to parse body", http.StatusBadRequest) return } + if req, ok := msg.(*jsonrpc.Request); ok { + if _, err := checkRequest(req, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } select { case t.incoming <- msg: w.WriteHeader(http.StatusAccepted) @@ -170,16 +162,19 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) // Connect sends the 'endpoint' event to the client. // See [SSEServerTransport] for more details on the [Connection] implementation. func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { - t.mu.Lock() - _, err := writeEvent(t.w, event{ - name: "endpoint", - data: []byte(t.endpoint), + if t.incoming != nil { + return nil, fmt.Errorf("already connected") + } + t.incoming = make(chan jsonrpc.Message, 100) + t.done = make(chan struct{}) + _, err := writeEvent(t.Response, Event{ + Name: "endpoint", + Data: []byte(t.Endpoint), }) - t.mu.Unlock() if err != nil { return nil, err } - return sseServerConn{t}, nil + return &sseServerConn{t: t}, nil } func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -227,7 +222,7 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - transport := NewSSEServerTransport(endpoint.RequestURI(), w) + transport := &SSEServerTransport{Endpoint: endpoint.RequestURI(), Response: w} // The session is terminated when the request exits. h.mu.Lock() @@ -239,9 +234,13 @@ func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { h.mu.Unlock() }() - // TODO(hxjiang): getServer returns nil will panic. server := h.getServer(req) - ss, err := server.Connect(req.Context(), transport) + if server == nil { + // The getServer argument to NewSSEHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + ss, err := server.Connect(req.Context(), transport, nil) if err != nil { http.Error(w, "connection failed", http.StatusInternalServerError) return @@ -264,10 +263,10 @@ type sseServerConn struct { } // TODO(jba): get the session ID. (Not urgent because SSE transports have been removed from the spec.) -func (s sseServerConn) SessionID() string { return "" } +func (s *sseServerConn) SessionID() string { return "" } // Read implements jsonrpc2.Reader. -func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (s *sseServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -279,7 +278,7 @@ func (s sseServerConn) Read(ctx context.Context) (JSONRPCMessage, error) { } // Write implements jsonrpc2.Writer. -func (s sseServerConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (s *sseServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { if ctx.Err() != nil { return ctx.Err() } @@ -299,7 +298,7 @@ func (s sseServerConn) Write(ctx context.Context, msg JSONRPCMessage) error { return io.EOF } - _, err = writeEvent(s.t.w, event{name: "message", data: data}) + _, err = writeEvent(s.t.Response, Event{Name: "message", Data: data}) return err } @@ -308,7 +307,7 @@ func (s sseServerConn) Write(ctx context.Context, msg JSONRPCMessage) error { // It must be safe to call Close more than once, as the close may // asynchronously be initiated by either the server closing its connection, or // by the hanging GET exiting. -func (s sseServerConn) Close() error { +func (s *sseServerConn) Close() error { s.t.mu.Lock() defer s.t.mu.Unlock() if !s.t.closed { @@ -324,43 +323,25 @@ func (s sseServerConn) Close() error { // // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEClientTransport struct { - sseEndpoint *url.URL - opts SSEClientTransportOptions -} + // Endpoint is the SSE endpoint to connect to. + Endpoint string -// SSEClientTransportOptions provides options for the [NewSSEClientTransport] -// constructor. -type SSEClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client } -// NewSSEClientTransport returns a new client transport that connects to the -// SSE server at the provided URL. -// -// NewSSEClientTransport panics if the given URL is invalid. -func NewSSEClientTransport(baseURL string, opts *SSEClientTransportOptions) *SSEClientTransport { - url, err := url.Parse(baseURL) - if err != nil { - panic(fmt.Sprintf("invalid base url: %v", err)) - } - t := &SSEClientTransport{ - sseEndpoint: url, - } - if opts != nil { - t.opts = *opts - } - return t -} - // Connect connects through the client endpoint. func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { - req, err := http.NewRequestWithContext(ctx, "GET", c.sseEndpoint.String(), nil) + parsedURL, err := url.Parse(c.Endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint: %v", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", c.Endpoint, nil) if err != nil { return nil, err } - httpClient := c.opts.HTTPClient + httpClient := c.HTTPClient if httpClient == nil { httpClient = http.DefaultClient } @@ -371,18 +352,18 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { } msgEndpoint, err := func() (*url.URL, error) { - var evt event + var evt Event for evt, err = range scanEvents(resp.Body) { break } if err != nil { return nil, err } - if evt.name != "endpoint" { - return nil, fmt.Errorf("first event is %q, want %q", evt.name, "endpoint") + if evt.Name != "endpoint" { + return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") } - raw := string(evt.data) - return c.sseEndpoint.Parse(raw) + raw := string(evt.Data) + return parsedURL.Parse(raw) }() if err != nil { resp.Body.Close() @@ -391,7 +372,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientConn{ - sseEndpoint: c.sseEndpoint, + client: httpClient, msgEndpoint: msgEndpoint, incoming: make(chan []byte, 100), body: resp.Body, @@ -406,7 +387,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return } select { - case s.incoming <- evt.data: + case s.incoming <- evt.Data: case <-s.done: return } @@ -416,104 +397,15 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return s, nil } -// scanEvents iterates SSE events in the given scanner. The iterated error is -// terminal: if encountered, the stream is corrupt or broken and should no -// longer be used. -// -// TODO(rfindley): consider a different API here that makes failure modes more -// apparent. -func scanEvents(r io.Reader) iter.Seq2[event, error] { - scanner := bufio.NewScanner(r) - const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size - scanner.Buffer(nil, maxTokenSize) - - // TODO: investigate proper behavior when events are out of order, or have - // non-standard names. - var ( - eventKey = []byte("event") - idKey = []byte("id") - dataKey = []byte("data") - ) - - return func(yield func(event, error) bool) { - // iterate event from the wire. - // https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples - // - // - `key: value` line records. - // - Consecutive `data: ...` fields are joined with newlines. - // - Unrecognized fields are ignored. Since we only care about 'event', 'id', and - // 'data', these are the only three we consider. - // - Lines starting with ":" are ignored. - // - Records are terminated with two consecutive newlines. - var ( - evt event - dataBuf *bytes.Buffer // if non-nil, preceding field was also data - ) - flushData := func() { - if dataBuf != nil { - evt.data = dataBuf.Bytes() - dataBuf = nil - } - } - for scanner.Scan() { - line := scanner.Bytes() - if len(line) == 0 { - flushData() - // \n\n is the record delimiter - if !evt.empty() && !yield(evt, nil) { - return - } - evt = event{} - continue - } - before, after, found := bytes.Cut(line, []byte{':'}) - if !found { - yield(event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line))) - return - } - if !bytes.Equal(before, dataKey) { - flushData() - } - switch { - case bytes.Equal(before, eventKey): - evt.name = strings.TrimSpace(string(after)) - case bytes.Equal(before, idKey): - evt.id = strings.TrimSpace(string(after)) - case bytes.Equal(before, dataKey): - data := bytes.TrimSpace(after) - if dataBuf != nil { - dataBuf.WriteByte('\n') - dataBuf.Write(data) - } else { - dataBuf = new(bytes.Buffer) - dataBuf.Write(data) - } - } - } - if err := scanner.Err(); err != nil { - if errors.Is(err, bufio.ErrTooLong) { - err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) - } - if !yield(event{}, err) { - return - } - } - flushData() - if !evt.empty() { - yield(evt, nil) - } - } -} - // An sseClientConn is a logical jsonrpc2 connection that implements the client // half of the SSE protocol: // - Writes are POSTS to the session endpoint. // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientConn struct { - sseEndpoint *url.URL // SSE endpoint for the GET - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan []byte // queue of incoming messages mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -530,7 +422,7 @@ func (c *sseClientConn) isDone() bool { return c.closed } -func (c *sseClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() @@ -551,7 +443,7 @@ func (c *sseClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { } } -func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (c *sseClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { return err @@ -564,7 +456,7 @@ func (c *sseClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { return err } req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := c.client.Do(req) if err != nil { return err } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 70f84c3e..6132d31e 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -15,29 +15,30 @@ import ( ) type AddParams struct { - X, Y int + X int `json:"x"` + Y int `json:"y"` } -func Add(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[AddParams]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("%d", params.Arguments.X+params.Arguments.Y)}, + &mcp.TextContent{Text: fmt.Sprintf("%d", args.X+args.Y)}, }, - }, nil + }, nil, nil } func ExampleSSEHandler() { - server := mcp.NewServer("adder", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("add", "add two numbers", Add)) + server := mcp.NewServer(&mcp.Implementation{Name: "adder", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) - handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) + handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }, nil) httpServer := httptest.NewServer(handler) defer httpServer.Close() ctx := context.Background() - transport := mcp.NewSSEClientTransport(httpServer.URL, nil) - client := mcp.NewClient("test", "v1.0.0", nil) - cs, err := client.Connect(ctx, transport) + transport := &mcp.SSEClientTransport{Endpoint: httpServer.URL} + client := mcp.NewClient(&mcp.Implementation{Name: "test", Version: "v1.0.0"}, nil) + cs, err := client.Connect(ctx, transport, nil) if err != nil { log.Fatal(err) } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 23621931..25435ff3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -5,11 +5,13 @@ package mcp import ( + "bytes" "context" "fmt" + "io" "net/http" "net/http/httptest" - "strings" + "sync/atomic" "testing" "github.com/google/go-cmp/cmp" @@ -19,32 +21,43 @@ func TestSSEServer(t *testing.T) { for _, closeServerFirst := range []bool{false, true} { t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) { ctx := context.Background() - server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(NewServerTool("greet", "say hi", sayHi)) + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) - sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) - conns := make(chan *ServerSession, 1) - sseHandler.onConnection = func(cc *ServerSession) { + serverSessions := make(chan *ServerSession, 1) + sseHandler.onConnection = func(ss *ServerSession) { select { - case conns <- cc: + case serverSessions <- ss: default: } } httpServer := httptest.NewServer(sseHandler) defer httpServer.Close() - clientTransport := NewSSEClientTransport(httpServer.URL, nil) + var customClientUsed int64 + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt64(&customClientUsed, 1) + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := &SSEClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + } - c := NewClient("testClient", "v1.0.0", nil) - cs, err := c.Connect(ctx, clientTransport) + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, clientTransport, nil) if err != nil { t.Fatal(err) } if err := cs.Ping(ctx, nil); err != nil { t.Fatal(err) } - ss := <-conns + ss := <-serverSessions gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", Arguments: map[string]any{"Name": "user"}, @@ -57,10 +70,48 @@ func TestSSEServer(t *testing.T) { &TextContent{Text: "hi user"}, }, } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { + if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } + // Verify that customClient was used + if atomic.LoadInt64(&customClientUsed) == 0 { + t.Error("Expected custom HTTP client to be used, but it wasn't") + } + + t.Run("badrequests", func(t *testing.T) { + msgEndpoint := cs.mcpConn.(*sseClientConn).msgEndpoint.String() + + // Test some invalid data, and verify that we get 400s. + badRequests := []struct { + name string + body string + responseContains string + }{ + {"not a method", `{"jsonrpc":"2.0", "method":"notamethod"}`, "not handled"}, + {"missing ID", `{"jsonrpc":"2.0", "method":"ping"}`, "missing id"}, + } + for _, r := range badRequests { + t.Run(r.name, func(t *testing.T) { + resp, err := http.Post(msgEndpoint, "application/json", bytes.NewReader([]byte(r.body))) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusBadRequest; got != want { + t.Errorf("Sending bad request %q: got status %d, want %d", r.body, got, want) + } + result, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Reading response: %v", err) + } + if !bytes.Contains(result, []byte(r.responseContains)) { + t.Errorf("Response body does not contain %q:\n%s", r.responseContains, string(result)) + } + }) + } + }) + // Test that closing either end of the connection terminates the other // end. if closeServerFirst { @@ -74,91 +125,9 @@ func TestSSEServer(t *testing.T) { } } -func TestScanEvents(t *testing.T) { - tests := []struct { - name string - input string - want []event - wantErr string - }{ - { - name: "simple event", - input: "event: message\nid: 1\ndata: hello\n\n", - want: []event{ - {name: "message", id: "1", data: []byte("hello")}, - }, - }, - { - name: "multiple data lines", - input: "data: line 1\ndata: line 2\n\n", - want: []event{ - {data: []byte("line 1\nline 2")}, - }, - }, - { - name: "multiple events", - input: "data: first\n\nevent: second\ndata: second\n\n", - want: []event{ - {data: []byte("first")}, - {name: "second", data: []byte("second")}, - }, - }, - { - name: "no trailing newline", - input: "data: hello", - want: []event{ - {data: []byte("hello")}, - }, - }, - { - name: "malformed line", - input: "invalid line\n\n", - wantErr: "malformed line", - }, - } +// roundTripperFunc is a helper to create a custom RoundTripper +type roundTripperFunc func(*http.Request) (*http.Response, error) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := strings.NewReader(tt.input) - var got []event - var err error - for e, err2 := range scanEvents(r) { - if err2 != nil { - err = err2 - break - } - got = append(got, e) - } - - if tt.wantErr != "" { - if err == nil { - t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr) - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr) - } - return - } - - if err != nil { - t.Fatalf("scanEvents() returned unexpected error: %v", err) - } - - if len(got) != len(tt.want) { - t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want)) - } - - for i := range got { - if g, w := got[i].name, tt.want[i].name; g != w { - t.Errorf("event %d: name = %q, want %q", i, g, w) - } - if g, w := got[i].id, tt.want[i].id; g != w { - t.Errorf("event %d: id = %q, want %q", i, g, w) - } - if g, w := string(got[i].data), string(tt.want[i].data); g != w { - t.Errorf("event %d: data = %q, want %q", i, g, w) - } - } - }) - } +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) } diff --git a/mcp/streamable.go b/mcp/streamable.go index da950fb2..4ab343b2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -7,15 +7,29 @@ package mcp import ( "bytes" "context" + "encoding/json" + "errors" "fmt" "io" + "iter" + "math" + "math/rand/v2" "net/http" + "slices" "strconv" "strings" "sync" "sync/atomic" + "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" ) // A StreamableHTTPHandler is an http.Handler that serves streamable MCP @@ -24,43 +38,67 @@ import ( // [MCP spec]: https://modelcontextprotocol.io/2025/03/26/streamable-http-transport.html type StreamableHTTPHandler struct { getServer func(*http.Request) *Server + opts StreamableHTTPOptions + + onTransportDeletion func(sessionID string) // for testing only - sessionsMu sync.Mutex - sessions map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) + mu sync.Mutex + // TODO: we should store the ServerSession along with the transport, because + // we need to cancel keepalive requests when closing the transport. + transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header) } -// StreamableHTTPOptions is a placeholder options struct for future -// configuration of the StreamableHTTP handler. +// StreamableHTTPOptions configures the StreamableHTTPHandler. type StreamableHTTPOptions struct { - // TODO(rfindley): support configurable session ID generation and event - // store, session retention, and event retention. + // Stateless controls whether the session is 'stateless'. + // + // A stateless server does not validate the Mcp-Session-Id header, and uses a + // temporary session with default initialization parameters. Any + // server->client request is rejected immediately as there's no way for the + // client to respond. Server->Client notifications may reach the client if + // they are made in the context of an incoming request, as described in the + // documentation for [StreamableServerTransport]. + Stateless bool + + // TODO(#148): support session retention (?) + + // JSONResponse causes streamable responses to return application/json rather + // than text/event-stream ([§2.1.5] of the spec). + // + // [§2.1.5]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + JSONResponse bool } // NewStreamableHTTPHandler returns a new [StreamableHTTPHandler]. // // The getServer function is used to create or look up servers for new // sessions. It is OK for getServer to return the same server multiple times. +// If getServer returns nil, a 400 Bad Request will be served. func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler { - return &StreamableHTTPHandler{ - getServer: getServer, - sessions: make(map[string]*StreamableServerTransport), + h := &StreamableHTTPHandler{ + getServer: getServer, + transports: make(map[string]*StreamableServerTransport), + } + if opts != nil { + h.opts = *opts } + return h } // closeAll closes all ongoing sessions. // // TODO(rfindley): investigate the best API for callers to configure their -// session lifecycle. +// session lifecycle. (?) // // Should we allow passing in a session store? That would allow the handler to // be stateless. func (h *StreamableHTTPHandler) closeAll() { - h.sessionsMu.Lock() - defer h.sessionsMu.Unlock() - for _, s := range h.sessions { - s.Close() + h.mu.Lock() + defer h.mu.Unlock() + for _, s := range h.transports { + s.connection.Close() } - h.sessions = nil + h.transports = nil } func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -70,9 +108,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var jsonOK, streamOK bool for _, c := range accept { switch strings.TrimSpace(c) { - case "application/json": + case "application/json", "application/*": + jsonOK = true + case "text/event-stream", "text/*": + streamOK = true + case "*/*": jsonOK = true - case "text/event-stream": streamOK = true } } @@ -82,172 +123,385 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "Accept must contain 'text/event-stream' for GET requests", http.StatusBadRequest) return } - } else if !jsonOK || !streamOK { + } else if (!jsonOK || !streamOK) && req.Method != http.MethodDelete { // TODO: consolidate with handling of http method below. http.Error(w, "Accept must contain both 'application/json' and 'text/event-stream'", http.StatusBadRequest) return } - var session *StreamableServerTransport - if id := req.Header.Get("Mcp-Session-Id"); id != "" { - h.sessionsMu.Lock() - session, _ = h.sessions[id] - h.sessionsMu.Unlock() - if session == nil { + sessionID := req.Header.Get(sessionIDHeader) + var transport *StreamableServerTransport + if sessionID != "" { + h.mu.Lock() + transport = h.transports[sessionID] + h.mu.Unlock() + if transport == nil && !h.opts.Stateless { + // Unless we're in 'stateless' mode, which doesn't perform any Session-ID + // validation, we require that the session ID matches a known session. + // + // In stateless mode, a temporary transport is be created below. http.Error(w, "session not found", http.StatusNotFound) return } } - // TODO(rfindley): simplify the locking so that each request has only one - // critical section. if req.Method == http.MethodDelete { - if session == nil { - // => Mcp-Session-Id was not set; else we'd have returned NotFound above. - http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) + if sessionID == "" { + http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } - h.sessionsMu.Lock() - delete(h.sessions, session.id) - h.sessionsMu.Unlock() - session.Close() + if transport != nil { // transport may be nil in stateless mode + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + transport.connection.Close() + } w.WriteHeader(http.StatusNoContent) return } switch req.Method { case http.MethodPost, http.MethodGet: + if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { + http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) + return + } default: - w.Header().Set("Allow", "GET, POST") - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + w.Header().Set("Allow", "GET, POST, DELETE") + http.Error(w, "Method Not Allowed: streamable MCP servers support GET, POST, and DELETE requests", http.StatusMethodNotAllowed) return } - if session == nil { - s := NewStreamableServerTransport(randText()) + // [§2.7] of the spec (2025-06-18) states: + // + // "If using HTTP, the client MUST include the MCP-Protocol-Version: + // HTTP header on all subsequent requests to the MCP + // server, allowing the MCP server to respond based on the MCP protocol + // version. + // + // For example: MCP-Protocol-Version: 2025-06-18 + // The protocol version sent by the client SHOULD be the one negotiated during + // initialization. + // + // For backwards compatibility, if the server does not receive an + // MCP-Protocol-Version header, and has no other way to identify the version - + // for example, by relying on the protocol version negotiated during + // initialization - the server SHOULD assume protocol version 2025-03-26. + // + // If the server receives a request with an invalid or unsupported + // MCP-Protocol-Version, it MUST respond with 400 Bad Request." + // + // Since this wasn't present in the 2025-03-26 version of the spec, this + // effectively means: + // 1. IF the client provides a version header, it must be a supported + // version. + // 2. In stateless mode, where we've lost the state of the initialize + // request, we assume that whatever the client tells us is the truth (or + // assume 2025-03-26 if the client doesn't say anything). + // + // This logic matches the typescript SDK. + // + // [§2.7]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + if !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) + return + } + + if transport == nil { server := h.getServer(req) + if server == nil { + // The getServer argument to NewStreamableHTTPHandler returned nil. + http.Error(w, "no server available", http.StatusBadRequest) + return + } + if sessionID == "" { + // In stateless mode, sessionID may be nonempty even if there's no + // existing transport. + sessionID = server.opts.GetSessionID() + } + transport = &StreamableServerTransport{ + SessionID: sessionID, + Stateless: h.opts.Stateless, + jsonResponse: h.opts.JSONResponse, + } + + // To support stateless mode, we initialize the session with a default + // state, so that it doesn't reject subsequent requests. + var connectOpts *ServerSessionOptions + if h.opts.Stateless { + // Peek at the body to see if it is initialize or initialized. + // We want those to be handled as usual. + var hasInitialize, hasInitialized bool + { + // TODO: verify that this allows protocol version negotiation for + // stateless servers. + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + req.Body.Close() + + // Reset the body so that it can be read later. + req.Body = io.NopCloser(bytes.NewBuffer(body)) + + msgs, _, err := readBatch(body) + if err == nil { + for _, msg := range msgs { + if req, ok := msg.(*jsonrpc.Request); ok { + switch req.Method { + case methodInitialize: + hasInitialize = true + case notificationInitialized: + hasInitialized = true + } + } + } + } + } + + // If we don't have InitializeParams or InitializedParams in the request, + // set the initial state to a default value. + state := new(ServerSessionState) + if !hasInitialize { + state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion, + } + } + if !hasInitialized { + state.InitializedParams = new(InitializedParams) + } + state.LogLevel = "info" + connectOpts = &ServerSessionOptions{ + State: state, + } + } else { + // Cleanup is only required in stateful mode, as transportation is + // not stored in the map otherwise. + connectOpts = &ServerSessionOptions{ + onClose: func() { + h.mu.Lock() + delete(h.transports, transport.SessionID) + h.mu.Unlock() + if h.onTransportDeletion != nil { + h.onTransportDeletion(transport.SessionID) + } + }, + } + } + // Pass req.Context() here, to allow middleware to add context values. // The context is detached in the jsonrpc2 library when handling the // long-running stream. - if _, err := server.Connect(req.Context(), s); err != nil { + ss, err := server.Connect(req.Context(), transport, connectOpts) + if err != nil { http.Error(w, "failed connection", http.StatusInternalServerError) return } - h.sessionsMu.Lock() - h.sessions[s.id] = s - h.sessionsMu.Unlock() - session = s + if h.opts.Stateless { + // Stateless mode: close the session when the request exits. + defer ss.Close() // close the fake session after handling the request + } else { + // Otherwise, save the transport so that it can be reused + h.mu.Lock() + h.transports[transport.SessionID] = transport + h.mu.Unlock() + } } - session.ServeHTTP(w, req) + transport.ServeHTTP(w, req) } -// NewStreamableServerTransport returns a new [StreamableServerTransport] with -// the given session ID. -// -// A StreamableServerTransport implements the server-side of the streamable +// A StreamableServerTransport implements the server side of the MCP streamable // transport. // -// TODO(rfindley): consider adding options here, to configure event storage -// policy. -func NewStreamableServerTransport(sessionID string) *StreamableServerTransport { - return &StreamableServerTransport{ - id: sessionID, - incoming: make(chan JSONRPCMessage, 10), - done: make(chan struct{}), - outgoingMessages: make(map[streamID][]*streamableMsg), - signals: make(map[streamID]chan struct{}), - requestStreams: make(map[JSONRPCID]streamID), - streamRequests: make(map[streamID]map[JSONRPCID]struct{}), - } +// Each StreamableServerTransport must be connected (via [Server.Connect]) at +// most once, since [StreamableServerTransport.ServeHTTP] serves messages to +// the connected session. +// +// Reads from the streamable server connection receive messages from http POST +// requests from the client. Writes to the streamable server connection are +// sent either to the hanging POST response, or to the hanging GET, according +// to the following rules: +// - JSON-RPC responses to incoming requests are always routed to the +// appropriate HTTP response. +// - Requests or notifications made with a context.Context value derived from +// an incoming request handler, are routed to the HTTP response +// corresponding to that request, unless it has already terminated, in +// which case they are routed to the hanging GET. +// - Requests or notifications made with a detached context.Context value are +// routed to the hanging GET. +type StreamableServerTransport struct { + // SessionID is the ID of this session. + // + // If SessionID is the empty string, this is a 'stateless' session, which has + // limited ability to communicate with the client. Otherwise, the session ID + // must be globally unique, that is, different from any other session ID + // anywhere, past and future. (We recommend using a crypto random number + // generator to produce one, as with [crypto/rand.Text].) + SessionID string + + // Stateless controls whether the eventstore is 'Stateless'. Server sessions + // connected to a stateless transport are disallowed from making outgoing + // requests. + // + // See also [StreamableHTTPOptions.Stateless]. + Stateless bool + + // Storage for events, to enable stream resumption. + // If nil, a [MemoryEventStore] with the default maximum size will be used. + EventStore EventStore + + // jsonResponse, if set, tells the server to prefer to respond to requests + // using application/json responses rather than text/event-stream. + // + // Specifically, responses will be application/json whenever incoming POST + // request contain only a single message. In this case, notifications or + // requests made within the context of a server request will be sent to the + // hanging GET request, if any. + // + // TODO(rfindley): jsonResponse should be exported, since + // StreamableHTTPOptions.JSONResponse is exported. + jsonResponse bool + + // connection is non-nil if and only if the transport has been connected. + connection *streamableServerConn } -func (t *StreamableServerTransport) SessionID() string { - return t.id +// Connect implements the [Transport] interface. +func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) { + if t.connection != nil { + return nil, fmt.Errorf("transport already connected") + } + t.connection = &streamableServerConn{ + sessionID: t.SessionID, + stateless: t.Stateless, + eventStore: t.EventStore, + jsonResponse: t.jsonResponse, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), + } + if t.connection.eventStore == nil { + t.connection.eventStore = NewMemoryEventStore(nil) + } + // Stream 0 corresponds to the hanging 'GET'. + // + // It is always text/event-stream, since it must carry arbitrarily many + // messages. + var err error + t.connection.streams[""], err = t.connection.newStream(ctx, "", false, false) + if err != nil { + return nil, err + } + return t.connection, nil } -// A StreamableServerTransport implements the [Transport] interface for a -// single session. -type StreamableServerTransport struct { - nextStreamID atomic.Int64 // incrementing next stream ID +type streamableServerConn struct { + sessionID string + stateless bool + jsonResponse bool + eventStore EventStore - id string - incoming chan JSONRPCMessage // messages from the client to the server + incoming chan jsonrpc.Message // messages from the client to the server - mu sync.Mutex + mu sync.Mutex // guards all fields below // Sessions are closed exactly once. isDone bool done chan struct{} - // Sessions can have multiple logical connections, corresponding to HTTP - // requests. Additionally, logical sessions may be resumed by subsequent HTTP - // requests, when the session is terminated unexpectedly. + // Sessions can have multiple logical connections (which we call streams), + // corresponding to HTTP requests. Additionally, streams may be resumed by + // subsequent HTTP requests, when the HTTP connection is terminated + // unexpectedly. // - // Therefore, we use a logical connection ID to key the connection state, and + // Therefore, we use a logical stream ID to key the stream state, and // perform the accounting described below when incoming HTTP requests are // handled. + + // streams holds the logical streams for this session, keyed by their ID. + // TODO: streams are never deleted, so the memory for a connection grows without + // bound. If we deleted a stream when the response is sent, we would lose the ability + // to replay if there was a cut just before the response was transmitted. + // Perhaps we could have a TTL for streams that starts just after the response. + streams map[string]*stream + + // requestStreams maps incoming requests to their logical stream ID. // - // The accounting is complicated. It is tempting to merge some of the maps - // below, but they each have different lifecycles, as indicated by Lifecycle: - // comments. + // Lifecycle: requestStreams persist for the duration of the session. // - // TODO(rfindley): simplify. + // TODO: clean up once requests are handled. See the TODO for streams above. + requestStreams map[jsonrpc.ID]string +} - // outgoingMessages is the collection of outgoingMessages messages, keyed by the logical - // stream ID where they should be delivered. - // - // streamID 0 is used for messages that don't correlate with an incoming - // request. +func (c *streamableServerConn) SessionID() string { + return c.sessionID +} + +// A stream is a single logical stream of SSE events within a server session. +// A stream begins with a client request, or with a client GET that has +// no Last-Event-ID header. +// +// A stream ends only when its session ends; we cannot determine its end otherwise, +// since a client may send a GET with a Last-Event-ID that references the stream +// at any time. +type stream struct { + // id is the logical ID for the stream, unique within a session. + // an empty string is used for messages that don't correlate with an incoming request. + id string + + // If isInitialize is set, the stream is in response to an initialize request, + // and therefore should include the session ID header. + isInitialize bool + + // jsonResponse records whether this stream should respond with application/json + // instead of text/event-stream. // - // Lifecycle: outgoingMessages persists for the duration of the session. + // See [StreamableServerTransportOptions.JSONResponse]. + jsonResponse bool + + // signal is a 1-buffered channel, owned by an incoming HTTP request, that signals + // that there are messages available to write into the HTTP response. + // In addition, the presence of a channel guarantees that at most one HTTP response + // can receive messages for a logical stream. After claiming the stream, incoming + // requests should read from the event store, to ensure that no new messages are missed. // - // TODO(rfindley): garbage collect this data. For now, we save all outgoingMessages - // messages for the lifespan of the transport. - outgoingMessages map[streamID][]*streamableMsg - - // signals maps a logical stream ID to a 1-buffered channel, owned by an - // incoming HTTP request, that signals that there are messages available to - // write into the HTTP response. Signals guarantees that at most one HTTP - // response can receive messages for a logical stream. After claiming - // the stream, incoming requests should read from outgoing, to ensure - // that no new messages are missed. + // To simplify locking, signal is an atomic. We need an atomic.Pointer, because + // you can't set an atomic.Value to nil. // - // Lifecycle: signals persists for the duration of an HTTP POST or GET - // request for the given streamID. - signals map[streamID]chan struct{} + // Lifecycle: each channel value persists for the duration of an HTTP POST or + // GET request for the given streamID. + signal atomic.Pointer[chan struct{}] - // requestStreams maps incoming requests to their logical stream ID. - // - // Lifecycle: requestStreams persists for the duration of the session. - // - // TODO(rfindley): clean up once requests are handled. - requestStreams map[JSONRPCID]streamID + // The following mutable fields are protected by the mutex of the containing + // StreamableServerTransport. - // outstandingRequests tracks the set of unanswered incoming RPCs for each logical - // stream. + // streamRequests is the set of unanswered incoming RPCs for the stream. // - // When the server has responded to each request, the stream should be - // closed. - // - // Lifecycle: outstandingRequests values persist as until the requests have been - // replied to by the server. Notably, NOT until they are sent to an HTTP - // response, as delivery is not guaranteed. - streamRequests map[streamID]map[JSONRPCID]struct{} + // Requests persist until their response data has been added to the event store. + requests map[jsonrpc.ID]struct{} } -type streamID int64 - -// a streamableMsg is an SSE event with an index into its logical stream. -type streamableMsg struct { - idx int - event event +func (c *streamableServerConn) newStream(ctx context.Context, id string, isInitialize, jsonResponse bool) (*stream, error) { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } + return &stream{ + id: id, + isInitialize: isInitialize, + jsonResponse: jsonResponse, + requests: make(map[jsonrpc.ID]struct{}), + }, nil } -// Connect implements the [Transport] interface. -// -// TODO(rfindley): Connect should return a new object. -func (s *StreamableServerTransport) Connect(context.Context) (Connection, error) { - return s, nil +func signalChanPtr() *chan struct{} { + c := make(chan struct{}, 1) + return &c } // We track the incoming request ID inside the handler context using @@ -266,7 +520,7 @@ func (s *StreamableServerTransport) Connect(context.Context) (Connection, error) // 2. Expose a 'HandlerTransport' interface that allows transports to provide // a handler middleware, so that we don't hard-code this behavior in // ServerSession.handle. -// 3. Add a `func ForRequest(context.Context) JSONRPCID` accessor that lets +// 3. Add a `func ForRequest(context.Context) jsonrpc.ID` accessor that lets // any transport access the incoming request ID. // // For now, by giving only the StreamableServerTransport access to the request @@ -275,46 +529,67 @@ type idContextKey struct{} // ServeHTTP handles a single HTTP request for the session. func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if t.connection == nil { + http.Error(w, "transport not connected", http.StatusInternalServerError) + return + } switch req.Method { case http.MethodGet: - t.serveGET(w, req) + t.connection.serveGET(w, req) case http.MethodPost: - t.servePOST(w, req) + t.connection.servePOST(w, req) default: // Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP. w.Header().Set("Allow", "GET, POST") http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + return } } -func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) { +// serveGET streams messages to a hanging http GET, with stream ID and last +// message parsed from the Last-Event-ID header. +// +// It returns an HTTP status code and error message. +func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id, nextIdx := streamID(0), 0 + id := "" + // By default, we haven't seen a last index. Since indices start at 0, we represent + // that by -1. This is incremented just before each event is written, in streamResponse + // around L407. + lastIdx := -1 if len(req.Header.Values("Last-Event-ID")) > 0 { eid := req.Header.Get("Last-Event-ID") var ok bool - id, nextIdx, ok = parseEventID(eid) + id, lastIdx, ok = parseEventID(eid) if !ok { http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest) return } - nextIdx++ } - t.mu.Lock() - if _, ok := t.signals[id]; ok { - http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest) - t.mu.Unlock() + c.mu.Lock() + stream, ok := c.streams[id] + c.mu.Unlock() + if !ok { + http.Error(w, "unknown stream", http.StatusBadRequest) return } - signal := make(chan struct{}, 1) - t.signals[id] = signal - t.mu.Unlock() - - t.streamResponse(w, req, id, nextIdx, signal) + if !stream.signal.CompareAndSwap(nil, signalChanPtr()) { + // The CAS returned false, meaning that the comparison failed: stream.signal is not nil. + http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) + return + } + defer stream.signal.Store(nil) + persistent := id == "" // Only the special stream "" is a hanging get. + c.respondSSE(stream, w, req, lastIdx, persistent) } -func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) { +// servePOST handles an incoming message, and replies with either an outgoing +// message stream or single response object, depending on whether the +// jsonResponse option is set. +// +// It returns an HTTP status code and error message. +func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) { if len(req.Header.Values("Last-Event-ID")) > 0 { http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) return @@ -330,120 +605,234 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) return } - incoming, _, err := readBatch(body) + incoming, isBatch, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } - requests := make(map[JSONRPCID]struct{}) + + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + if isBatch && protocolVersion >= protocolVersion20250618 { + http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest) + return + } + + requests := make(map[jsonrpc.ID]struct{}) + tokenInfo := auth.TokenInfoFromContext(req.Context()) + isInitialize := false for _, msg := range incoming { - if req, ok := msg.(*JSONRPCRequest); ok && req.ID.IsValid() { - requests[req.ID] = struct{}{} + if jreq, ok := msg.(*jsonrpc.Request); ok { + // Preemptively check that this is a valid request, so that we can fail + // the HTTP request. If we didn't do this, a request with a bad method or + // missing ID could be silently swallowed. + if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if jreq.Method == methodInitialize { + isInitialize = true + } + jreq.Extra = &RequestExtra{ + TokenInfo: tokenInfo, + Header: req.Header, + } + if jreq.IsCall() { + requests[jreq.ID] = struct{}{} + } } } - // Update accounting for this request. - id := streamID(t.nextStreamID.Add(1)) - signal := make(chan struct{}, 1) - t.mu.Lock() + var stream *stream // if non-nil, used to handle requests + + // If we have requests, we need to handle responses along with any + // notifications or server->client requests made in the course of handling. + // Update accounting for this incoming payload. if len(requests) > 0 { - t.streamRequests[id] = make(map[JSONRPCID]struct{}) - } - for reqID := range requests { - t.requestStreams[reqID] = id - t.streamRequests[id][reqID] = struct{}{} + stream, err = c.newStream(req.Context(), randText(), isInitialize, c.jsonResponse) + if err != nil { + http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) + return + } + c.mu.Lock() + c.streams[stream.id] = stream + stream.requests = requests + for reqID := range requests { + c.requestStreams[reqID] = stream.id + } + c.mu.Unlock() + stream.signal.Store(signalChanPtr()) + defer stream.signal.Store(nil) } - t.signals[id] = signal - t.mu.Unlock() // Publish incoming messages. for _, msg := range incoming { - t.incoming <- msg + c.incoming <- msg } - // TODO(rfindley): consider optimizing for a single incoming request, by - // responding with application/json when there is only a single message in - // the response. - t.streamResponse(w, req, id, 0, signal) + if stream == nil { + w.WriteHeader(http.StatusAccepted) + return + } + + if stream.jsonResponse { + c.respondJSON(stream, w, req) + } else { + c.respondSSE(stream, w, req, -1, false) + } } -func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *http.Request, id streamID, nextIndex int, signal chan struct{}) { - defer func() { - t.mu.Lock() - delete(t.signals, id) - t.mu.Unlock() - }() +func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter, req *http.Request) { + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "application/json") + if c.sessionID != "" && stream.isInitialize { + w.Header().Set(sessionIDHeader, c.sessionID) + } - // Stream resumption: adjust outgoing index based on what the user says - // they've received. - if nextIndex > 0 { - t.mu.Lock() - // Clamp nextIndex to outgoing messages. - outgoing := t.outgoingMessages[id] - if nextIndex > len(outgoing) { - nextIndex = len(outgoing) + var msgs []json.RawMessage + ctx := req.Context() + for msg, err := range c.messages(ctx, stream, false, -1) { + if err != nil { + if ctx.Err() != nil { + w.WriteHeader(http.StatusNoContent) + return + } else { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + } + msgs = append(msgs, msg) + } + var data []byte + if len(msgs) == 1 { + data = []byte(msgs[0]) + } else { + // TODO: add tests for batch responses, or disallow them entirely. + var err error + data, err = json.Marshal(msgs) + if err != nil { + http.Error(w, fmt.Sprintf("internal error marshalling response: %v", err), http.StatusInternalServerError) + return } - t.mu.Unlock() } + _, _ = w.Write(data) // ignore error: client disconnected +} - w.Header().Set("Mcp-Session-Id", t.id) - w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] +// lastIndex is the index of the last seen event if resuming, else -1. +func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { + // Accept was checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") + if c.sessionID != "" && stream.isInitialize { + w.Header().Set(sessionIDHeader, c.sessionID) + } + if persistent { + // Issue #410: the hanging GET is likely not to receive messages for a long + // time. Ensure that headers are flushed. + // + // For non-persistent requests, delay the writing of the header in case we + // may want to set an error status. + // (see the TODO: this probably isn't worth it). + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + // write one event containing data. writes := 0 -stream: - for { - // Send outgoing messages - t.mu.Lock() - outgoing := t.outgoingMessages[id][nextIndex:] - t.mu.Unlock() - - for _, msg := range outgoing { - if _, err := writeEvent(w, msg.event); err != nil { - // Connection closed or broken. - return - } - writes++ - nextIndex++ - } - - t.mu.Lock() - nOutstanding := len(t.streamRequests[id]) - nOutgoing := len(t.outgoingMessages[id]) - t.mu.Unlock() - // If all requests have been handled and replied to, we can terminate this - // connection. However, in the case of a sequencing violation from the server - // (a send on the request context after the request has been handled), we - // loop until we've written all messages. - // - // TODO(rfindley): should we instead refuse to send messages after the last - // response? Decide, write a test, and change the behavior. - if nextIndex < nOutgoing { - continue // more to send - } - if req.Method == http.MethodPost && nOutstanding == 0 { - if writes == 0 { - // Spec: If the server accepts the input, the server MUST return HTTP - // status code 202 Accepted with no body. - w.WriteHeader(http.StatusAccepted) + write := func(data []byte) bool { + lastIndex++ + e := Event{ + Name: "message", + ID: formatEventID(stream.id, lastIndex), + Data: data, + } + if _, err := writeEvent(w, e); err != nil { + // Connection closed or broken. + // TODO(#170): log when we add server-side logging. + return false + } + writes++ + return true + } + + // Repeatedly collect pending outgoing events and send them. + ctx := req.Context() + for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { + if err != nil { + if ctx.Err() == nil && writes == 0 && !persistent { + // If we haven't yet written the header, we have an opportunity to + // promote an error to an HTTP error. + // + // TODO: This may not matter in practice, in which case we should + // simplify. + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } else { + // TODO(#170): log when we add server-side logging } return } + if !write(msg) { + return + } + } +} - select { - case <-signal: - case <-t.done: - if writes == 0 { - http.Error(w, "session terminated", http.StatusGone) +// messages iterates over messages sent to the current stream. +// +// persistent indicates if it is the main GET listener, which should never be +// terminated. +// lastIndex is the index of the last seen event, iteration begins at lastIndex+1. +// +// The first iterated value is the received JSON message. The second iterated +// value is an error value indicating whether the stream terminated normally. +// Iteration ends at the first non-nil error. +// +// If the stream did not terminate normally, it is either because ctx was +// cancelled, or the connection is closed: check the ctx.Err() to differentiate +// these cases. +func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] { + return func(yield func(json.RawMessage, error) bool) { + for { + c.mu.Lock() + nOutstanding := len(stream.requests) + c.mu.Unlock() + for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) { + if err != nil { + yield(nil, err) + return + } + if !yield(data, nil) { + return + } + lastIndex++ } - break stream - case <-req.Context().Done(): - if writes == 0 { - w.WriteHeader(http.StatusNoContent) + // If all requests have been handled and replied to, we should terminate this connection. + // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." + // §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + // We only want to terminate POSTs, and GETs that are replaying. The general-purpose GET + // (stream ID 0) will never have requests, and should remain open indefinitely. + if nOutstanding == 0 && !persistent { + return + } + + select { + case <-*stream.signal.Load(): // there are new outgoing messages + // return to top of loop + case <-c.done: // session is closed + yield(nil, errors.New("session is closed")) + return + case <-ctx.Done(): + yield(nil, ctx.Err()) + return } - break stream } + } } @@ -454,71 +843,73 @@ stream: // streamID and message index idx. // // See also [parseEventID]. -func formatEventID(sid streamID, idx int) string { - return fmt.Sprintf("%d_%d", sid, idx) +func formatEventID(sid string, idx int) string { + return fmt.Sprintf("%s_%d", sid, idx) } // parseEventID parses a Last-Event-ID value into a logical stream id and // index. // // See also [formatEventID]. -func parseEventID(eventID string) (sid streamID, idx int, ok bool) { +func parseEventID(eventID string) (streamID string, idx int, ok bool) { parts := strings.Split(eventID, "_") if len(parts) != 2 { - return 0, 0, false - } - stream, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil || stream < 0 { - return 0, 0, false + return "", 0, false } - idx, err = strconv.Atoi(parts[1]) + streamID = parts[0] + idx, err := strconv.Atoi(parts[1]) if err != nil || idx < 0 { - return 0, 0, false + return "", 0, false } - return streamID(stream), idx, true + return streamID, idx, true } // Read implements the [Connection] interface. -func (t *StreamableServerTransport) Read(ctx context.Context) (JSONRPCMessage, error) { +func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error) { select { case <-ctx.Done(): return nil, ctx.Err() - case msg, ok := <-t.incoming: + case msg, ok := <-c.incoming: if !ok { return nil, io.EOF } return msg, nil - case <-t.done: + case <-c.done: return nil, io.EOF } } // Write implements the [Connection] interface. -func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessage) error { +func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { + // Requests aren't possible with stateless servers, or when there's no session ID. + return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) + } // Find the incoming request that this write relates to, if any. - var forRequest, replyTo JSONRPCID - if resp, ok := msg.(*JSONRPCResponse); ok { + var forRequest jsonrpc.ID + isResponse := false + if resp, ok := msg.(*jsonrpc.Response); ok { // If the message is a response, it relates to its request (of course). forRequest = resp.ID - replyTo = resp.ID + isResponse = true } else { // Otherwise, we check to see if it request was made in the context of an - // ongoing request. This may not be the case if the request way made with + // ongoing request. This may not be the case if the request was made with // an unrelated context. if v := ctx.Value(idContextKey{}); v != nil { - forRequest = v.(JSONRPCID) + forRequest = v.(jsonrpc.ID) } } // Find the logical connection corresponding to this request. // // For messages sent outside of a request context, this is the default - // connection 0. - var forConn streamID + // connection "". + var forStream string if forRequest.IsValid() { - t.mu.Lock() - forConn = t.requestStreams[forRequest] - t.mu.Unlock() + c.mu.Lock() + forStream = c.requestStreams[forRequest] + c.mu.Unlock() } data, err := jsonrpc2.EncodeMessage(msg) @@ -526,41 +917,41 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessag return err } - t.mu.Lock() - defer t.mu.Unlock() - if t.isDone { - return fmt.Errorf("session is closed") // TODO: should this be EOF? + c.mu.Lock() + defer c.mu.Unlock() + if c.isDone { + return errors.New("session is closed") + } + + stream := c.streams[forStream] + if stream == nil { + return fmt.Errorf("no stream with ID %s", forStream) } - if _, ok := t.streamRequests[forConn]; !ok && forConn != 0 { - // No outstanding requests for this connection, which means it is logically - // done. This is a sequencing violation from the server, so we should report - // a side-channel error here. Put the message on the general queue to avoid - // dropping messages. - forConn = 0 + // Special case a few conditions where we fall back on stream 0 (the hanging GET): + // + // - if forStream is known, but the associated stream is logically complete + // - if the stream is application/json, but the message is not a response + // + // TODO(rfindley): either of these, particularly the first, might be + // considered a bug in the server. Report it through a side-channel? + if len(stream.requests) == 0 && forStream != "" || stream.jsonResponse && !isResponse { + stream = c.streams[""] } - idx := len(t.outgoingMessages[forConn]) - t.outgoingMessages[forConn] = append(t.outgoingMessages[forConn], &streamableMsg{ - idx: idx, - event: event{ - name: "message", - id: formatEventID(forConn, idx), - data: data, - }, - }) - if replyTo.IsValid() { + if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil { + return fmt.Errorf("error storing event: %w", err) + } + if isResponse { // Once we've put the reply on the queue, it's no longer outstanding. - delete(t.streamRequests[forConn], replyTo) - if len(t.streamRequests[forConn]) == 0 { - delete(t.streamRequests, forConn) - } + delete(stream.requests, forRequest) } - // Signal work. - if c, ok := t.signals[forConn]; ok { + // Signal streamResponse that new work is available. + signalp := stream.signal.Load() + if signalp != nil { select { - case c <- struct{}{}: + case *signalp <- struct{}{}: default: } } @@ -568,12 +959,15 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg JSONRPCMessag } // Close implements the [Connection] interface. -func (t *StreamableServerTransport) Close() error { - t.mu.Lock() - defer t.mu.Unlock() - if !t.isDone { - t.isDone = true - close(t.done) +func (c *streamableServerConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.isDone { + c.isDone = true + close(c.done) + // TODO: find a way to plumb a context here, or an event store with a long-running + // close operation can take arbitrary time. Alternative: impose a fixed timeout here. + return c.eventStore.SessionClosed(context.TODO(), c.sessionID) } return nil } @@ -581,30 +975,26 @@ func (t *StreamableServerTransport) Close() error { // A StreamableClientTransport is a [Transport] that can communicate with an MCP // endpoint serving the streamable HTTP transport defined by the 2025-03-26 // version of the spec. -// -// TODO(rfindley): support retries and resumption tokens. type StreamableClientTransport struct { - url string - opts StreamableClientTransportOptions -} - -// StreamableClientTransportOptions provides options for the -// [NewStreamableClientTransport] constructor. -type StreamableClientTransportOptions struct { - // HTTPClient is the client to use for making HTTP requests. If nil, - // http.DefaultClient is used. + Endpoint string HTTPClient *http.Client + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // It defaults to 5. To disable retries, use a negative number. + MaxRetries int } -// NewStreamableClientTransport returns a new client transport that connects to -// the streamable HTTP server at the provided URL. -func NewStreamableClientTransport(url string, opts *StreamableClientTransportOptions) *StreamableClientTransport { - t := &StreamableClientTransport{url: url} - if opts != nil { - t.opts = *opts - } - return t -} +// These settings are not (yet) exposed to the user in +// StreamableClientTransport. +const ( + // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. + // It must be 1.0 or greater if MaxRetries is greater than 0. + reconnectGrowFactor = 1.5 + // reconnectInitialDelay is the base delay for the first reconnect attempt. + reconnectInitialDelay = 1 * time.Second + // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. + reconnectMaxDelay = 30 * time.Second +) // Connect implements the [Transport] interface. // @@ -615,159 +1005,445 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt // When closed, the connection issues a DELETE request to terminate the logical // session. func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, error) { - client := t.opts.HTTPClient + client := t.HTTPClient if client == nil { client = http.DefaultClient } - return &streamableClientConn{ - url: t.url, - client: client, - incoming: make(chan []byte, 100), - done: make(chan struct{}), - }, nil + maxRetries := t.MaxRetries + if maxRetries == 0 { + maxRetries = 5 + } else if maxRetries < 0 { + maxRetries = 0 + } + // Create a new cancellable context that will manage the connection's lifecycle. + // This is crucial for cleanly shutting down the background SSE listener by + // cancelling its blocking network operations, which prevents hangs on exit. + connCtx, cancel := context.WithCancel(context.Background()) + conn := &streamableClientConn{ + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), + } + return conn, nil } type streamableClientConn struct { - url string - client *http.Client - incoming chan []byte - done chan struct{} - + url string + client *http.Client + ctx context.Context + cancel context.CancelFunc + incoming chan jsonrpc.Message + maxRetries int + + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error + done chan struct{} // signal graceful termination - mu sync.Mutex - _sessionID string - // bodies map[*http.Response]io.Closer - err error + // Logical reads are distributed across multiple http requests. Whenever any + // of them fails to process their response, we must break the connection, by + // failing the pending Read. + // + // Achieve this by storing the failure message, and signalling when reads are + // broken. See also [streamableClientConn.fail] and + // [streamableClientConn.failure]. + failOnce sync.Once + _failure error + failed chan struct{} // signal failure + + // Guard the initialization state. + mu sync.Mutex + initializedResult *InitializeResult + sessionID string +} + +// errSessionMissing distinguishes if the session is known to not be present on +// the server (see [streamableClientConn.fail]). +// +// TODO(rfindley): should we expose this error value (and its corresponding +// API) to the user? +// +// The spec says that if the server returns 404, clients should reestablish +// a session. For now, we delegate that to the user, but do they need a way to +// differentiate a 'NotFound' error from other errors? +var errSessionMissing = errors.New("session not found") + +var _ clientConnection = (*streamableClientConn)(nil) + +func (c *streamableClientConn) sessionUpdated(state clientSessionState) { + c.mu.Lock() + c.initializedResult = state.InitializeResult + c.mu.Unlock() + + // Start the persistent SSE listener as soon as we have the initialized + // result. + // + // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be + // used to open an SSE stream, allowing the server to communicate to the + // client, without the client first sending data via HTTP POST. + // + // We have to wait for initialized, because until we've received + // initialized, we don't know whether the server requires a sessionID. + // + // § 2.5: A server using the Streamable HTTP transport MAY assign a session + // ID at initialization time, by including it in an Mcp-Session-Id header + // on the HTTP response containing the InitializeResult. + go c.handleSSE("hanging GET", nil, true, nil) +} + +// fail handles an asynchronous error while reading. +// +// If err is non-nil, it is terminal, and subsequent (or pending) Reads will +// fail. +// +// If err wraps errSessionMissing, the failure indicates that the session is no +// longer present on the server, and no final DELETE will be performed when +// closing the connection. +func (c *streamableClientConn) fail(err error) { + if err != nil { + c.failOnce.Do(func() { + c._failure = err + close(c.failed) + }) + } +} + +func (c *streamableClientConn) failure() error { + select { + case <-c.failed: + return c._failure + default: + return nil + } } func (c *streamableClientConn) SessionID() string { c.mu.Lock() defer c.mu.Unlock() - return c._sessionID + return c.sessionID } // Read implements the [Connection] interface. -func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) { +func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { + if err := c.failure(); err != nil { + return nil, err + } select { case <-ctx.Done(): return nil, ctx.Err() - case <-s.done: + case <-c.failed: + return nil, c.failure() + case <-c.done: return nil, io.EOF - case data := <-s.incoming: - return jsonrpc2.DecodeMessage(data) + case msg := <-c.incoming: + return msg, nil } } // Write implements the [Connection] interface. -func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { - s.mu.Lock() - if s.err != nil { - s.mu.Unlock() - return s.err - } - - sessionID := s._sessionID - if sessionID == "" { - // Hold lock for the first request. - defer s.mu.Unlock() - } else { - s.mu.Unlock() - } - - gotSessionID, err := s.postMessage(ctx, sessionID, msg) - if err != nil { - if sessionID != "" { - // unlocked; lock to set err - s.mu.Lock() - defer s.mu.Unlock() - } - if s.err != nil { - s.err = err - } +func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { + if err := c.failure(); err != nil { return err } - if sessionID == "" { - // locked - s._sessionID = gotSessionID + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") } - return nil -} - -func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg JSONRPCMessage) (string, error) { - data, err := jsonrpc2.EncodeMessage(msg) + data, err := jsonrpc.EncodeMessage(msg) if err != nil { - return "", err + return fmt.Errorf("%s: %v", requestSummary, err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.url, bytes.NewReader(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) if err != nil { - return "", err - } - if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + return err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + c.setMCPHeaders(req) - resp, err := s.client.Do(req) + resp, err := c.client.Do(req) if err != nil { - return "", err + return fmt.Errorf("%s: %v", requestSummary, err) } + // §2.5.3: "The server MAY terminate the session at any time, after + // which it MUST respond to requests containing that session ID with HTTP + // 404 Not Found." + if resp.StatusCode == http.StatusNotFound { + // Fail the session immediately, rather than relying on jsonrpc2 to fail + // (and close) it, because we want the call to Close to know that this + // session is missing (and therefore not send the DELETE). + err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing) + c.fail(err) + resp.Body.Close() + return err + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - // TODO: do a best effort read of the body here, and format it in the error. resp.Body.Close() - return "", fmt.Errorf("broken session: %v", resp.Status) + return fmt.Errorf("broken session: %v", resp.Status) } - sessionID = resp.Header.Get("Mcp-Session-Id") - if resp.Header.Get("Content-Type") == "text/event-stream" { - go s.handleSSE(resp) - } else { + if sessionID := resp.Header.Get(sessionIDHeader); sessionID != "" { + c.mu.Lock() + hadSessionID := c.sessionID + if hadSessionID == "" { + c.sessionID = sessionID + } + c.mu.Unlock() + if hadSessionID != "" && hadSessionID != sessionID { + resp.Body.Close() + return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID) + } + } + if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted { + resp.Body.Close() + return nil + } + + switch ct := resp.Header.Get("Content-Type"); ct { + case "application/json": + go c.handleJSON(requestSummary, resp) + + case "text/event-stream": + jsonReq, _ := msg.(*jsonrpc.Request) + go c.handleSSE(requestSummary, resp, false, jsonReq) + + default: resp.Body.Close() + return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct) } - return sessionID, nil + return nil } -func (s *streamableClientConn) handleSSE(resp *http.Response) { - defer resp.Body.Close() +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth = false - done := make(chan struct{}) - go func() { - defer close(done) - for evt, err := range scanEvents(resp.Body) { - if err != nil { - // TODO: surface this error; possibly break the stream +func (c *streamableClientConn) setMCPHeaders(req *http.Request) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.initializedResult != nil { + req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) + } + if c.sessionID != "" { + req.Header.Set(sessionIDHeader, c.sessionID) + } + if testAuth { + req.Header.Set("Authorization", "Bearer foo") + } +} + +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) + return + } + select { + case c.incoming <- msg: + case <-c.done: + // The connection was closed by the client; exit gracefully. + } +} + +// handleSSE manages the lifecycle of an SSE connection. It can be either +// persistent (for the main GET listener) or temporary (for a POST response). +// +// If forReq is set, it is the request that initiated the stream, and the +// stream is complete when we receive its response. +func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { + resp := initialResp + var lastEventID string + for { + if resp != nil { + eventID, clientClosed := c.processStream(requestSummary, resp, forReq) + lastEventID = eventID + + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && !persistent { return } - s.incoming <- evt.data } - }() - select { - case <-s.done: - case <-done: + // The stream was interrupted or ended by the server. Attempt to reconnect. + newResp, err := c.reconnect(lastEventID) + if err != nil { + // All reconnection attempts failed: fail the connection. + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err)) + return + } + resp = newResp + if resp.StatusCode == http.StatusMethodNotAllowed && persistent { + // The server doesn't support the hanging GET. + resp.Body.Close() + return + } + // (see equivalent handling in [streamableClientConn.Write]). + if resp.StatusCode == http.StatusNotFound { + c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing)) + return + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + resp.Body.Close() + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) + return + } + // Reconnection was successful. Continue the loop with the new response. } } -// Close implements the [Connection] interface. -func (s *streamableClientConn) Close() error { - s.closeOnce.Do(func() { - close(s.done) +// processStream reads from a single response body, sending events to the +// incoming channel. It returns the ID of the last processed event and a flag +// indicating if the connection was closed by the client. If resp is nil, it +// returns "", false. +func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { + defer resp.Body.Close() + for evt, err := range scanEvents(resp.Body) { + if err != nil { + return lastEventID, false + } + + if evt.ID != "" { + lastEventID = evt.ID + } - req, err := http.NewRequest(http.MethodDelete, s.url, nil) + msg, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { - s.closeErr = err + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) + return "", true + } + + select { + case c.incoming <- msg: + if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil { + // TODO: we should never get a response when forReq is nil (the hanging GET). + // We should detect this case, and eliminate the 'persistent' flag arguments. + if jsonResp.ID == forReq.ID { + return "", true + } + } + case <-c.done: + // The connection was closed by the client; exit gracefully. + return "", true + } + } + // The loop finished without an error, indicating the server closed the stream. + return "", false +} + +// reconnect handles the logic of retrying a connection with an exponential +// backoff strategy. It returns a new, valid HTTP response if successful, or +// an error if all retries are exhausted. +func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { + var finalErr error + + // We can reach the 'reconnect' path through the hanging GET, in which case + // lastEventID will be "". + // + // In this case, we need an initial attempt. + attempt := 0 + if lastEventID != "" { + attempt = 1 + } + + for ; attempt <= c.maxRetries; attempt++ { + select { + case <-c.done: + return nil, fmt.Errorf("connection closed by client during reconnect") + case <-time.After(calculateReconnectDelay(attempt)): + resp, err := c.establishSSE(lastEventID) + if err != nil { + finalErr = err // Store the error and try again. + continue + } + return resp, nil + } + } + // If the loop completes, all retries have failed. + if finalErr != nil { + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) + } + return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) +} + +// Close implements the [Connection] interface. +func (c *streamableClientConn) Close() error { + c.closeOnce.Do(func() { + // Cancel any hanging network requests. + c.cancel() + close(c.done) + + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. } else { - req.Header.Set("Mcp-Session-Id", s._sessionID) - if _, err := s.client.Do(req); err != nil { - s.closeErr = err + req, err := http.NewRequest(http.MethodDelete, c.url, nil) + if err != nil { + c.closeErr = err + } else { + c.setMCPHeaders(req) + if _, err := c.client.Do(req); err != nil { + c.closeErr = err + } } } }) - return s.closeErr + return c.closeErr +} + +// establishSSE establishes the persistent SSE listening stream. +// It is used for reconnect attempts using the Last-Event-ID header to +// resume a broken stream where it left off. +func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, error) { + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) + if err != nil { + return nil, err + } + c.setMCPHeaders(req) + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + + return c.client.Do(req) +} + +// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. +func calculateReconnectDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } + // Calculate the exponential backoff using the grow factor. + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1))) + // Cap the backoffDuration at maxDelay. + backoffDuration = min(backoffDuration, reconnectMaxDelay) + + // Use a full jitter using backoffDuration + jitter := rand.N(backoffDuration) + + return backoffDuration + jitter } diff --git a/mcp/streamable_bench_test.go b/mcp/streamable_bench_test.go new file mode 100644 index 00000000..07f81da3 --- /dev/null +++ b/mcp/streamable_bench_test.go @@ -0,0 +1,66 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func BenchmarkStreamableServing(b *testing.B) { + // This benchmark measures how fast we can handle a single tool on a + // streamable server, including tool validation and stream management. + customSchemas := map[any]*jsonschema.Schema{ + Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, + WeatherType(""): {Type: "string", Enum: []any{Sunny, PartlyCloudy, Cloudy, Rainy, Snowy}}, + } + opts := &jsonschema.ForOptions{TypeSchemas: customSchemas} + in, err := jsonschema.For[WeatherInput](opts) + if err != nil { + b.Fatal(err) + } + out, err := jsonschema.For[WeatherOutput](opts) + if err != nil { + b.Fatal(err) + } + + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{ + Name: "weather", + InputSchema: in, + OutputSchema: out, + }, WeatherTool) + + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + session, err := mcp.NewClient(testImpl, nil).Connect(ctx, &mcp.StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + b.Fatal(err) + } + defer session.Close() + b.ResetTimer() + for range b.N { + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "weather", + Arguments: WeatherInput{ + Location: Location{Name: "somewhere"}, + Days: 7, + }, + }); err != nil { + b.Errorf("CallTool failed: %v", err) + } + } +} diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go new file mode 100644 index 00000000..001d3a64 --- /dev/null +++ b/mcp/streamable_client_test.go @@ -0,0 +1,325 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +type streamableRequestKey struct { + httpMethod string // http method + sessionID string // session ID header + jsonrpcMethod string // jsonrpc method, or "" for non-requests +} + +type header map[string]string + +type streamableResponse struct { + header header // response headers + status int // or http.StatusOK + body string // or "" + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + callback func() // if set, called after the request is handled +} + +type fakeResponses map[streamableRequestKey]*streamableResponse + +type fakeStreamableServer struct { + t *testing.T + responses fakeResponses + + callMu sync.Mutex + calls map[streamableRequestKey]int +} + +func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { + s.callMu.Lock() + defer s.callMu.Unlock() + + var unused []streamableRequestKey + for k, resp := range s.responses { + if s.calls[k] == 0 && !resp.optional { + unused = append(unused, k) + } + } + return unused +} + +func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + key := streamableRequestKey{ + httpMethod: req.Method, + sessionID: req.Header.Get(sessionIDHeader), + } + if req.Method == http.MethodPost { + body, err := io.ReadAll(req.Body) + if err != nil { + s.t.Errorf("failed to read body: %v", err) + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + s.t.Errorf("invalid body: %v", err) + http.Error(w, "invalid body", http.StatusInternalServerError) + return + } + if r, ok := msg.(*jsonrpc.Request); ok { + key.jsonrpcMethod = r.Method + } + } + + s.callMu.Lock() + if s.calls == nil { + s.calls = make(map[streamableRequestKey]int) + } + s.calls[key]++ + s.callMu.Unlock() + + resp, ok := s.responses[key] + if !ok { + s.t.Errorf("missing response for %v", key) + http.Error(w, "no response", http.StatusInternalServerError) + return + } + if resp.callback != nil { + defer resp.callback() + } + for k, v := range resp.header { + w.Header().Set(k, v) + } + status := resp.status + if status == 0 { + status = http.StatusOK + } + w.WriteHeader(status) + + if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { + s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) + } + w.Write([]byte(resp.body)) +} + +var ( + initResult = &InitializeResult{ + Capabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + } + initResp = resp(1, initResult, nil) +) + +func jsonBody(t *testing.T, msg jsonrpc2.Message) string { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatalf("encoding failed: %v", err) + } + return string(data) +} + +func TestStreamableClientTransportLifecycle(t *testing.T) { + ctx := context.Background() + + // The lifecycle test verifies various behavior of the streamable client + // initialization: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + optional: true, + wantProtocolVersion: latestProtocolVersion, + }, + {"DELETE", "123", ""}: {}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } +} + +func TestStreamableClientRedundantDelete(t *testing.T) { + ctx := context.Background() + + // The lifecycle test verifies various behavior of the streamable client + // initialization: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + status: http.StatusMethodNotAllowed, + optional: true, + }, + {"POST", "123", methodListTools}: { + status: http.StatusNotFound, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + _, err = session.ListTools(ctx, nil) + if err == nil { + t.Errorf("Listing tools: got nil error, want non-nil") + } + _ = session.Wait() // must not hang + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } +} + +func TestStreamableClientGETHandling(t *testing.T) { + ctx := context.Background() + + tests := []struct { + status int + wantErrorContaining string + }{ + {http.StatusOK, ""}, + {http.StatusMethodNotAllowed, ""}, + {http.StatusBadRequest, "hanging GET"}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) { + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: test.status, + wantProtocolVersion: latestProtocolVersion, + }, + {"POST", "123", methodListTools}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), + optional: true, + }, + {"DELETE", "123", ""}: {optional: true}, + }, + } + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // wait for all required requests to be handled, with exponential + // backoff. + start := time.Now() + delay := 1 * time.Millisecond + for range 10 { + if len(fake.missingRequests()) == 0 { + break + } + time.Sleep(delay) + delay *= 2 + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) + } + + _, err = session.ListTools(ctx, nil) + if (err != nil) != (test.wantErrorContaining != "") { + t.Errorf("After initialization, got error %v, want %v", err, test.wantErrorContaining) + } else if err != nil { + if !strings.Contains(err.Error(), test.wantErrorContaining) { + t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining) + } + } + + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + }) + } +} diff --git a/mcp/streamable_example_test.go b/mcp/streamable_example_test.go new file mode 100644 index 00000000..430f2745 --- /dev/null +++ b/mcp/streamable_example_test.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+streamablehandler + +func ExampleStreamableHTTPHandler() { + // Create a new streamable handler, using the same MCP server for every request. + // + // Here, we configure it to serves application/json responses rather than + // text/event-stream, just so the output below doesn't use random event ids. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, &mcp.StreamableHTTPOptions{JSONResponse: true}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + resp := mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + fmt.Println(resp) + // Output: + // {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.1.0"}}} +} + +// !-streamablehandler + +// !+httpmiddleware + +func ExampleStreamableHTTPHandler_middleware() { + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.1.0"}, nil) + handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { + return server + }, nil) + loggingHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Example debugging; you could also capture the response. + body, err := io.ReadAll(req.Body) + if err != nil { + log.Fatal(err) + } + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewBuffer(body)) + fmt.Println(req.Method, string(body)) + handler.ServeHTTP(w, req) + }) + httpServer := httptest.NewServer(loggingHandler) + defer httpServer.Close() + + // The SDK is currently permissive of some missing keys in "params". + mustPostMessage(`{"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}}`, httpServer.URL) + // Output: + // POST {"jsonrpc": "2.0", "id": 1, "method":"initialize", "params": {}} +} + +// !-httpmiddleware + +func mustPostMessage(msg, url string) string { + req := orFatal(http.NewRequest("POST", url, strings.NewReader(msg))) + req.Header["Content-Type"] = []string{"application/json"} + req.Header["Accept"] = []string{"application/json", "text/event-stream"} + resp := orFatal(http.DefaultClient.Do(req)) + defer resp.Body.Close() + body := orFatal(io.ReadAll(resp.Body)) + return string(body) +} + +func orFatal[T any](t T, err error) T { + if err != nil { + log.Fatal(err) + } + return t +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a8c916e8..3b967f8f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -8,20 +8,29 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "maps" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" + "net/http/httputil" "net/url" + "sort" "strings" "sync" "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestStreamableTransports(t *testing.T) { @@ -30,162 +39,650 @@ func TestStreamableTransports(t *testing.T) { ctx := context.Background() - // 1. Create a server with a simple "greet" tool. - server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(NewServerTool("greet", "say hi", sayHi)) + for _, useJSON := range []bool{false, true} { + t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) { + // Create a server with some simple tools. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + // The "hang" tool checks that context cancellation is propagated. + // It hangs until the context is cancelled. + var ( + start = make(chan struct{}) + cancelled = make(chan struct{}, 1) // don't block the request + ) + hang := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + start <- struct{}{} + select { + case <-ctx.Done(): + cancelled <- struct{}{} + case <-time.After(5 * time.Second): + return nil, nil, nil + } + return nil, nil, nil + } + AddTool(server, &Tool{Name: "hang"}, hang) + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { + // Test that we can make sampling requests during tool handling. + // + // Try this on both the request context and a background context, so + // that messages may be delivered on either the POST or GET connection. + for _, ctx := range map[string]context.Context{ + "request context": ctx, + "background context": context.Background(), + } { + res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{}) + if err != nil { + return nil, nil, err + } + if g, w := res.Model, "aModel"; g != w { + return nil, nil, fmt.Errorf("got %q, want %q", g, w) + } + } + return &CallToolResult{}, nil, nil + }) - // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a - // cookie-checking middleware. - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie("test-cookie") - if err != nil { - t.Errorf("missing cookie: %v", err) - } else if cookie.Value != "test-value" { - t.Errorf("got cookie %q, want %q", cookie.Value, "test-value") - } - handler.ServeHTTP(w, r) - })) - defer httpServer.Close() + // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a + // cookie-checking middleware. + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + JSONResponse: useJSON, + }) - // 3. Create a client and connect it to the server using our StreamableClientTransport. - // Check that all requests honor a custom client. - jar, err := cookiejar.New(nil) - if err != nil { - t.Fatal(err) + var ( + headerMu sync.Mutex + lastHeader http.Header + ) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerMu.Lock() + lastHeader = r.Header + headerMu.Unlock() + cookie, err := r.Cookie("test-cookie") + if err != nil { + t.Errorf("missing cookie: %v", err) + } else if cookie.Value != "test-value" { + t.Errorf("got cookie %q, want %q", cookie.Value, "test-value") + } + handler.ServeHTTP(w, r) + })) + defer httpServer.Close() + + // Create a client and connect it to the server using our StreamableClientTransport. + // Check that all requests honor a custom client. + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatal(err) + } + u, err := url.Parse(httpServer.URL) + if err != nil { + t.Fatal(err) + } + jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}}) + httpClient := &http.Client{Jar: jar} + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: httpClient, + } + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + sid := session.ID() + if sid == "" { + t.Fatalf("empty session ID") + } + if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { + t.Fatalf("got protocol version %q, want %q", g, w) + } + + // Verify the behavior of various tools. + + // The "greet" tool should just work. + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "foo"}, + } + got, err := session.CallTool(ctx, params) + if err != nil { + t.Fatalf("CallTool() failed: %v", err) + } + if g := session.ID(); g != sid { + t.Errorf("session ID: got %q, want %q", g, sid) + } + if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w { + t.Errorf("got protocol version header %q, want %q", g, w) + } + want := &CallToolResult{ + Content: []Content{&TextContent{Text: "hi foo"}}, + } + if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" { + t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) + } + + // The "hang" tool should be cancellable. + ctx2, cancel := context.WithCancel(context.Background()) + go session.CallTool(ctx2, &CallToolParams{Name: "hang"}) + <-start + cancel() + select { + case <-cancelled: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for cancellation") + } + + // The "sampling" tool should be able to issue sampling requests during + // tool operation. + result, err := session.CallTool(ctx, &CallToolParams{ + Name: "sample", + Arguments: map[string]any{}, + }) + if err != nil { + t.Fatal(err) + } + if result.IsError { + t.Fatalf("tool failed: %s", result.Content[0].(*TextContent).Text) + } + }) } - u, err := url.Parse(httpServer.URL) +} + +func TestStreamableServerShutdown(t *testing.T) { + ctx := context.Background() + + // This test checks that closing the streamable HTTP server actually results + // in client session termination, provided one of following holds: + // 1. The server is stateful, and therefore the hanging GET fails the connection. + // 2. The server is stateless, and the client uses a KeepAlive. + tests := []struct { + name string + stateless, keepalive bool + }{ + {"stateful", false, false}, + {"stateless with keepalive", true, true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := NewServer(testImpl, nil) + // Add a tool, just so we can check things are working. + AddTool(server, &Tool{Name: "greet"}, sayHi) + + handler := NewStreamableHTTPHandler( + func(req *http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: test.stateless}) + + // When we shut down the server, we need to explicitly close ongoing + // connections. Otherwise, the hanging GET may never terminate. + httpServer := httptest.NewUnstartedServer(handler) + httpServer.Config.RegisterOnShutdown(func() { + for session := range server.Sessions() { + session.Close() + } + }) + httpServer.Start() + defer httpServer.Close() + + // Connect and run a tool. + var opts ClientOptions + if test.keepalive { + opts.KeepAlive = 50 * time.Millisecond + } + client := NewClient(testImpl, &opts) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + MaxRetries: -1, // avoid slow tests during exponential retries + }, nil) + if err != nil { + t.Fatal(err) + } + defer clientSession.Close() + + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "foo"}, + } + // Verify that we can call a tool. + if _, err := clientSession.CallTool(ctx, params); err != nil { + t.Fatalf("CallTool() failed: %v", err) + } + + // Shut down the server. Sessions should terminate. + go func() { + if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("closing http server: %v", err) + } + }() + + // Wait may return an error (after all, the connection failed), but it + // should not hang. + t.Log("Client waiting") + _ = clientSession.Wait() + }) + } +} + +// TestClientReplay verifies that the client can recover from a mid-stream +// network failure and receive replayed messages (if replay is configured). It +// uses a proxy that is killed and restarted to simulate a recoverable network +// outage. +func TestClientReplay(t *testing.T) { + for _, test := range []clientReplayTest{ + {"default", 0, true}, + {"no retries", -1, false}, + } { + t.Run(test.name, func(t *testing.T) { + testClientReplay(t, test) + }) + } +} + +type clientReplayTest struct { + name string + maxRetries int + wantRecovered bool +} + +func testClientReplay(t *testing.T, test clientReplayTest) { + notifications := make(chan string) + // Configure the real MCP server. + server := NewServer(testImpl, nil) + + // Use a channel to synchronize the server's message sending with the test's + // proxy-killing action. + serverReadyToKillProxy := make(chan struct{}) + serverClosed := make(chan struct{}) + AddTool(server, &Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { + // Send one message to the request context, and another to a background + // context (which will end up on the hanging GET). + + bgCtx := context.Background() + req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg1"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + + // Signal the test that it can now kill the proxy. + close(serverReadyToKillProxy) + <-serverClosed + + // These messages should be queued for replay by the server after + // the client's connection drops. + req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg3"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + return new(CallToolResult), nil, nil + }) + + realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer realServer.Close() + realServerURL, err := url.Parse(realServer.URL) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to parse real server URL: %v", err) } - jar.SetCookies(u, []*http.Cookie{{Name: "test-cookie", Value: "test-value"}}) - httpClient := &http.Client{Jar: jar} - transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ - HTTPClient: httpClient, + + // Configure a proxy that sits between the client and the real server. + proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL) + proxy := httptest.NewServer(proxyHandler) + proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. + + // Configure the client to connect to the proxy with default options. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ + ProgressNotificationHandler: func(ctx context.Context, req *ProgressNotificationClientRequest) { + notifications <- req.Params.Message + }, }) - client := NewClient("testClient", "v1.0.0", nil) - session, err := client.Connect(ctx, transport) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: proxy.URL, + MaxRetries: test.maxRetries, + }, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer session.Close() - sid := session.ID() - if sid == "" { - t.Error("empty session ID") + defer clientSession.Close() + + var ( + wg sync.WaitGroup + callErr error + ) + wg.Add(1) + go func() { + defer wg.Done() + _, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) + }() + + select { + case <-serverReadyToKillProxy: + // Server has sent the first two messages and is paused. + case <-ctx.Done(): + t.Fatalf("Context timed out before server was ready to kill proxy") } - // 4. The client calls the "greet" tool. - params := &CallToolParams{ - Name: "greet", - Arguments: map[string]any{"name": "streamy"}, + + // We should always get the first two notifications. + msgs := readNotifications(t, ctx, notifications, 2) + sort.Strings(msgs) // notifications may arrive in either order + want := []string{"msg1", "msg2"} + if diff := cmp.Diff(want, msgs); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) } - got, err := session.CallTool(ctx, params) + + // Simulate a total network failure by closing the proxy. + t.Log("--- Killing proxy to simulate network failure ---") + proxy.CloseClientConnections() + proxy.Close() + close(serverClosed) + + // Simulate network recovery by restarting the proxy on the same address. + t.Logf("--- Restarting proxy on %s ---", proxyAddr) + listener, err := net.Listen("tcp", proxyAddr) if err != nil { - t.Fatalf("CallTool() failed: %v", err) + t.Fatalf("Failed to listen on proxy address: %v", err) } - if g := session.ID(); g != sid { - t.Errorf("session ID: got %q, want %q", g, sid) + + restartedProxy := &http.Server{Handler: proxyHandler} + go restartedProxy.Serve(listener) + defer restartedProxy.Close() + + wg.Wait() + + if test.wantRecovered { + // If we've recovered, we should get all 4 notifications and the tool call + // should have succeeded. + msgs := readNotifications(t, ctx, notifications, 2) + sort.Strings(msgs) + want := []string{"msg3", "msg4"} + if diff := cmp.Diff(want, msgs); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } + if callErr != nil { + t.Errorf("CallTool failed unexpectedly: %v", err) + } + } else { + // Otherwise, the call should fail. + if callErr == nil { + t.Errorf("CallTool succeeded unexpectedly") + } } +} + +func TestServerTransportCleanup(t *testing.T) { + nClient := 3 + + var mu sync.Mutex + var id int = -1 // session id starting from "0", "1", "2"... + chans := make(map[string]chan struct{}, nClient) - // 5. Verify that the correct response is received. - want := &CallToolResult{ - Content: []Content{ - &TextContent{Text: "hi streamy"}, + server := NewServer(testImpl, &ServerOptions{ + KeepAlive: 10 * time.Millisecond, + GetSessionID: func() string { + mu.Lock() + defer mu.Unlock() + id++ + if id == nClient { + t.Errorf("creating more than %v session", nClient) + } + chans[fmt.Sprint(id)] = make(chan struct{}, 1) + return fmt.Sprint(id) }, + }) + + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + handler.onTransportDeletion = func(sessionID string) { + chans[sessionID] <- struct{}{} } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) - } -} -func TestStreamableServerTransport(t *testing.T) { - // This test checks detailed behavior of the streamable server transport, by - // faking the behavior of a streamable client using a sequence of HTTP - // requests. + httpServer := httptest.NewServer(handler) + defer httpServer.Close() - // A step is a single step in the tests below, consisting of a request payload - // and expected response. - type step struct { - // If OnRequest is > 0, this step only executes after a request with the - // given ID is received. - // - // All OnRequest steps must occur before the step that creates the request. - // - // To avoid tests hanging when there's a bug, it's expected that this - // request is received in the course of a *synchronous* request to the - // server (otherwise, we wouldn't be able to terminate the test without - // analyzing a dependency graph). - OnRequest int64 - // If set, Async causes the step to run asynchronously to other steps. - // Redundant with OnRequest: all OnRequest steps are asynchronous. - Async bool - - Method string // HTTP request method - Send []JSONRPCMessage // messages to send - CloseAfter int // if nonzero, close after receiving this many messages - StatusCode int // expected status code - Recv []JSONRPCMessage // expected messages to receive - } - - // JSON-RPC message constructors. - req := func(id int64, method string, params any) *JSONRPCRequest { - r := &JSONRPCRequest{ - Method: method, - Params: mustMarshal(t, params), + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Spin up clients connect to the same server but refuse to ping request. + for range nClient { + client := NewClient(testImpl, nil) + pingMiddleware := func(next MethodHandler) MethodHandler { + return func( + ctx context.Context, + method string, + req Request, + ) (Result, error) { + if method == "ping" { + return &emptyResult{}, errors.New("ping error") + } + return next(ctx, method, req) + } } - if id > 0 { - r.ID = jsonrpc2.Int64ID(id) + client.AddReceivingMiddleware(pingMiddleware) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) } - return r + defer clientSession.Close() } - resp := func(id int64, result any, err error) *JSONRPCResponse { - return &JSONRPCResponse{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(t, result), - Error: err, + + for _, ch := range chans { + select { + case <-ctx.Done(): + t.Errorf("did not capture transport deletion event from all session in 10 seconds") + case <-ch: // Received transport deletion signal of this session } } + handler.mu.Lock() + if len(handler.transports) != 0 { + t.Errorf("want empty transports map, find %v entries from handler's transports map", len(handler.transports)) + } + handler.mu.Unlock() +} + +// TestServerInitiatedSSE verifies that the persistent SSE connection remains +// open and can receive server-initiated events. +func TestServerInitiatedSSE(t *testing.T) { + notifications := make(chan string) + server := NewServer(testImpl, nil) + + httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { + notifications <- "toolListChanged" + }, + }) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + AddTool(server, &Tool{Name: "testTool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return &CallToolResult{}, nil, nil + }) + receivedNotifications := readNotifications(t, ctx, notifications, 1) + wantReceived := []string{"toolListChanged"} + if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) + } +} + +// Helper to read a specific number of notifications. +func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { + t.Helper() + var collectedNotifications []string + for { + select { + case n := <-notifications: + collectedNotifications = append(collectedNotifications, n) + if len(collectedNotifications) == count { + return collectedNotifications + } + case <-ctx.Done(): + if len(collectedNotifications) != count { + t.Fatalf("readProgressNotifications(): did not receive expected notifications, got %d, want %d", len(collectedNotifications), count) + } + return collectedNotifications + } + } +} + +// JSON-RPC message constructors. +func req(id int64, method string, params any) *jsonrpc.Request { + r := &jsonrpc.Request{ + Method: method, + Params: mustMarshal(params), + } + if id > 0 { + r.ID = jsonrpc2.Int64ID(id) + } + return r +} + +func resp(id int64, result any, err error) *jsonrpc.Response { + return &jsonrpc.Response{ + ID: jsonrpc2.Int64ID(id), + Result: mustMarshal(result), + Error: err, + } +} + +func TestStreamableServerTransport(t *testing.T) { + // This test checks detailed behavior of the streamable server transport, by + // faking the behavior of a streamable client using a sequence of HTTP + // requests. + // Predefined steps, to avoid repetition below. - initReq := req(1, "initialize", &InitializeParams{}) + initReq := req(1, methodInitialize, &InitializeParams{}) initResp := resp(1, &InitializeResult{ - Capabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, - Tools: &toolCapabilities{ListChanged: true}, + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, }, - ProtocolVersion: "2025-03-26", - ServerInfo: &implementation{Name: "testServer", Version: "v1.0.0"}, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, }, nil) - initializedMsg := req(0, "initialized", &InitializedParams{}) - initialize := step{ - Method: "POST", - Send: []JSONRPCMessage{initReq}, - StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{initResp}, + initializedMsg := req(0, notificationInitialized, &InitializedParams{}) + initialize := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, } - initialized := step{ - Method: "POST", - Send: []JSONRPCMessage{initializedMsg}, - StatusCode: http.StatusAccepted, + initialized := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initializedMsg}, + wantStatusCode: http.StatusAccepted, } tests := []struct { - name string - tool func(*testing.T, context.Context, *ServerSession) - steps []step + name string + tool func(*testing.T, context.Context, *ServerSession) + requests []streamableRequest // http requests }{ { name: "basic", - steps: []step{ + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + }, + }, + }, + { + name: "accept headers", + requests: []streamableRequest{ + initialize, + initialized, + // Test various accept headers. + { + method: "POST", + headers: http.Header{"Accept": {"text/plain", "application/*"}}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // missing text/event-stream + }, + { + method: "POST", + headers: http.Header{"Accept": {"text/event-stream"}}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // missing application/json + }, + { + method: "POST", + headers: http.Header{"Accept": {"text/plain", "*/*"}}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{"Accept": {"text/*, application/*"}}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{Content: []Content{}}, nil)}, + }, + }, + }, + { + name: "protocol version headers", + requests: []streamableRequest{ initialize, initialized, { - Method: "POST", - Send: []JSONRPCMessage{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{resp(2, &CallToolResult{}, nil)}, + method: "POST", + headers: http.Header{"mcp-protocol-version": {"2025-01-01"}}, // an invalid protocol version + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "2025-03-26", // a supported version + wantSessionID: false, // could be true, but shouldn't matter + }, + }, + }, + { + name: "batch rejected on 2025-06-18", + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + // Explicitly set the protocol version header + headers: http.Header{"MCP-Protocol-Version": {"2025-06-18"}}, + // Two messages => batch. Expect reject. + messages: []jsonrpc.Message{ + req(101, "tools/call", &CallToolParams{Name: "tool"}), + req(102, "tools/call", &CallToolParams{Name: "tool"}), + }, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "batch", + }, + }, + }, + { + name: "batch accepted on 2025-03-26", + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + headers: http.Header{"MCP-Protocol-Version": {"2025-03-26"}}, + // Two messages => batch. Expect OK with two responses in order. + messages: []jsonrpc.Message{ + req(201, "tools/call", &CallToolParams{Name: "tool"}), + req(202, "tools/call", &CallToolParams{Name: "tool"}), + }, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ + resp(201, &CallToolResult{Content: []Content{}}, nil), + resp(202, &CallToolResult{Content: []Content{}}, nil), + }, }, }, }, @@ -197,18 +694,18 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []step{ + requests: []streamableRequest{ initialize, initialized, { - Method: "POST", - Send: []JSONRPCMessage{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), - resp(2, &CallToolResult{}, nil), + resp(2, &CallToolResult{Content: []Content{}}, nil), }, }, }, @@ -221,36 +718,36 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Call failed: %v", err) } }, - steps: []step{ + requests: []streamableRequest{ initialize, initialized, { - Method: "POST", - OnRequest: 1, - Send: []JSONRPCMessage{ + method: "POST", + onRequest: 1, + messages: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, - StatusCode: http.StatusAccepted, + wantStatusCode: http.StatusAccepted, }, { - Method: "POST", - Send: []JSONRPCMessage{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ req(1, "roots/list", &ListRootsParams{}), - resp(2, &CallToolResult{}, nil), + resp(2, &CallToolResult{Content: []Content{}}, nil), }, }, }, }, { name: "background", - tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { + tool: func(t *testing.T, _ context.Context, ss *ServerSession) { // Perform operations on a background context, and ensure the client // receives it. - ctx = context.Background() + ctx := context.Background() if err := ss.NotifyProgress(ctx, &ProgressNotificationParams{}); err != nil { t.Errorf("Notify failed: %v", err) } @@ -262,55 +759,72 @@ func TestStreamableServerTransport(t *testing.T) { t.Errorf("Notify failed: %v", err) } }, - steps: []step{ + requests: []streamableRequest{ initialize, initialized, { - Method: "POST", - OnRequest: 1, - Send: []JSONRPCMessage{ + method: "POST", + onRequest: 1, + messages: []jsonrpc.Message{ resp(1, &ListRootsResult{}, nil), }, - StatusCode: http.StatusAccepted, + wantStatusCode: http.StatusAccepted, }, { - Method: "GET", - Async: true, - StatusCode: http.StatusOK, - CloseAfter: 2, - Recv: []JSONRPCMessage{ + method: "GET", + async: true, + wantStatusCode: http.StatusOK, + closeAfter: 2, + wantMessages: []jsonrpc.Message{ req(0, "notifications/progress", &ProgressNotificationParams{}), req(1, "roots/list", &ListRootsParams{}), }, }, { - Method: "POST", - Send: []JSONRPCMessage{ + method: "POST", + messages: []jsonrpc.Message{ req(2, "tools/call", &CallToolParams{Name: "tool"}), }, - StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{ - resp(2, &CallToolResult{}, nil), + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{Content: []Content{}}, nil), }, }, + { + method: "DELETE", + wantStatusCode: http.StatusNoContent, + // Delete request expects 204 No Content with empty body. So override + // the default "accept: application/json, text/event-stream" header. + headers: map[string][]string{"Accept": nil}, + }, }, }, { name: "errors", - steps: []step{ + requests: []streamableRequest{ { - Method: "PUT", - StatusCode: http.StatusMethodNotAllowed, + method: "PUT", + wantStatusCode: http.StatusMethodNotAllowed, }, { - Method: "DELETE", - StatusCode: http.StatusBadRequest, + method: "DELETE", + wantStatusCode: http.StatusBadRequest, }, { - Method: "POST", - Send: []JSONRPCMessage{req(2, "tools/call", &CallToolParams{Name: "tool"})}, - StatusCode: http.StatusOK, - Recv: []JSONRPCMessage{resp(2, nil, &jsonrpc2.WireError{ + method: "POST", + messages: []jsonrpc.Message{req(1, "notamethod", nil)}, + wantStatusCode: http.StatusBadRequest, // notamethod is an invalid method + }, + { + method: "POST", + messages: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, // tools/call must have an ID + }, + { + method: "POST", + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, nil, &jsonrpc2.WireError{ Message: `method "tools/call" is invalid during session initialization`, })}, }, @@ -322,132 +836,177 @@ func TestStreamableServerTransport(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. - server := NewServer("testServer", "v1.0.0", nil) - tool := NewServerTool("tool", "test tool", func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { - if test.tool != nil { - test.tool(t, ctx, ss) - } - return &CallToolResultFor[any]{}, nil - }) - server.AddTools(tool) + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + if test.tool != nil { + test.tool(t, ctx, req.Session) + } + return &CallToolResult{}, nil + }) // Start the streamable handler. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() - httpServer := httptest.NewServer(handler) - defer httpServer.Close() + testStreamableHandler(t, handler, test.requests) + }) + } +} - // blocks records request blocks by JSONRPC ID. - // - // When an OnRequest step is encountered, it waits on the corresponding - // block. When a request with that ID is received, the block is closed. - var mu sync.Mutex - blocks := make(map[int64]chan struct{}) - for _, step := range test.steps { - if step.OnRequest > 0 { - blocks[step.OnRequest] = make(chan struct{}) - } +func testStreamableHandler(t *testing.T, handler http.Handler, requests []streamableRequest) { + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // blocks records request blocks by jsonrpc. ID. + // + // When an OnRequest step is encountered, it waits on the corresponding + // block. When a request with that ID is received, the block is closed. + var mu sync.Mutex + blocks := make(map[int64]chan struct{}) + for _, req := range requests { + if req.onRequest > 0 { + blocks[req.onRequest] = make(chan struct{}) + } + } + + // signal when all synchronous requests have executed, so we can fail + // async requests that are blocked. + syncRequestsDone := make(chan struct{}) + + // To avoid complicated accounting for session ID, just set the first + // non-empty session ID from a response. + var sessionID atomic.Value + sessionID.Store("") + + // doStep executes a single step. + doStep := func(t *testing.T, i int, request streamableRequest) { + if request.onRequest > 0 { + // Block the step until we've received the server->client request. + mu.Lock() + block := blocks[request.onRequest] + mu.Unlock() + select { + case <-block: + case <-syncRequestsDone: + t.Errorf("after all sync requests are complete, request still blocked on %d", request.onRequest) + return } + } - // signal when all synchronous requests have executed, so we can fail - // async requests that are blocked. - syncRequestsDone := make(chan struct{}) + // Collect messages received during this request, unblock other steps + // when requests are received. + var got []jsonrpc.Message + out := make(chan jsonrpc.Message) + // Cancel the step if we encounter a request that isn't going to be + // handled. + // + // Also, add a timeout (hopefully generous). + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - // To avoid complicated accounting for session ID, just set the first - // non-empty session ID from a response. - var sessionID atomic.Value - sessionID.Store("") + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() - // doStep executes a single step. - doStep := func(t *testing.T, step step) { - if step.OnRequest > 0 { - // Block the step until we've received the server->client request. + for m := range out { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { + // Encountered a server->client request. We should have a + // response queued. Otherwise, we may deadlock. mu.Lock() - block := blocks[step.OnRequest] - mu.Unlock() - select { - case <-block: - case <-syncRequestsDone: - t.Errorf("after all sync requests are complete, request still blocked on %d", step.OnRequest) - return + if block, ok := blocks[req.ID.Raw().(int64)]; ok { + close(block) + } else { + t.Errorf("no queued response for %v", req.ID) + cancel() } + mu.Unlock() } + got = append(got, m) + if request.closeAfter > 0 && len(got) == request.closeAfter { + cancel() + } + } + }() - // Collect messages received during this request, unblock other steps - // when requests are received. - var got []JSONRPCMessage - out := make(chan JSONRPCMessage) - // Cancel the step if we encounter a request that isn't going to be - // handled. - ctx, cancel := context.WithCancel(context.Background()) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - for m := range out { - if req, ok := m.(*JSONRPCRequest); ok && req.ID.IsValid() { - // Encountered a server->client request. We should have a - // response queued. Otherwise, we may deadlock. - mu.Lock() - if block, ok := blocks[req.ID.Raw().(int64)]; ok { - close(block) - } else { - t.Errorf("no queued response for %v", req.ID) - cancel() - } - mu.Unlock() - } - got = append(got, m) - if step.CloseAfter > 0 && len(got) == step.CloseAfter { - cancel() - } - } - }() - - gotSessionID, gotStatusCode, err := streamingRequest(ctx, - httpServer.URL, sessionID.Load().(string), step.Method, step.Send, out) + gotSessionID, gotStatusCode, gotBody, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) - // Don't fail on cancelled requests: error (if any) is handled - // elsewhere. - if err != nil && ctx.Err() == nil { - t.Fatal(err) - } + // Don't fail on cancelled requests: error (if any) is handled + // elsewhere. + if err != nil && ctx.Err() == nil { + t.Fatal(err) + } - if gotStatusCode != step.StatusCode { - t.Errorf("got status %d, want %d", gotStatusCode, step.StatusCode) - } - wg.Wait() + if gotStatusCode != request.wantStatusCode { + t.Errorf("request #%d: got status %d, want %d", i, gotStatusCode, request.wantStatusCode) + } + if got := gotSessionID != ""; got != request.wantSessionID { + t.Errorf("request #%d: got session id: %t, want %t", i, got, request.wantSessionID) + } + wg.Wait() - transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id JSONRPCID) any { return id.Raw() }) - if diff := cmp.Diff(step.Recv, got, transform); diff != "" { - t.Errorf("received unexpected messages (-want +got):\n%s", diff) - } - sessionID.CompareAndSwap("", gotSessionID) - } - - var wg sync.WaitGroup - for _, step := range test.steps { - if step.Async || step.OnRequest > 0 { - wg.Add(1) - go func() { - defer wg.Done() - doStep(t, step) - }() - } else { - doStep(t, step) - } + if request.wantBodyContaining != "" { + body := string(gotBody) + if !strings.Contains(body, request.wantBodyContaining) { + t.Errorf("body does not contain %q:\n%s", request.wantBodyContaining, body) } + } else { + transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) + if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { + t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff) + } + } + sessionID.CompareAndSwap("", gotSessionID) + } - // Fail any blocked responses if they weren't needed by a synchronous - // request. - close(syncRequestsDone) - - wg.Wait() - }) + var wg sync.WaitGroup + for i, request := range requests { + if request.async || request.onRequest > 0 { + wg.Add(1) + go func() { + defer wg.Done() + doStep(t, i, request) + }() + } else { + doStep(t, i, request) + } } + + // Fail any blocked responses if they weren't needed by a synchronous + // request. + close(syncRequestsDone) + + wg.Wait() +} + +// A streamableRequest describes a single streamable HTTP request, consisting +// of a request payload and expected response. +type streamableRequest struct { + // If onRequest is > 0, this step only executes after a request with the + // given ID is received. + // + // All onRequest steps must occur before the step that creates the request. + // + // To avoid tests hanging when there's a bug, it's expected that this + // request is received in the course of a *synchronous* request to the + // server (otherwise, we wouldn't be able to terminate the test without + // analyzing a dependency graph). + onRequest int64 + // If set, async causes the step to run asynchronously to other steps. + // Redundant with OnRequest: all OnRequest steps are asynchronous. + async bool + + // Request attributes + method string // HTTP request method (required) + headers http.Header // additional headers to set, overlaid on top of the default headers + messages []jsonrpc.Message // messages to send + + closeAfter int // if nonzero, close after receiving this many messages + wantStatusCode int // expected status code + wantBodyContaining string // if set, expect the response body to contain this text; overrides wantMessages + wantMessages []jsonrpc.Message // expected messages to receive; ignored if wantBodyContaining is set + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -463,112 +1022,143 @@ func TestStreamableServerTransport(t *testing.T) { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func streamingRequest(ctx context.Context, serverURL, sessionID, method string, in []JSONRPCMessage, out chan<- JSONRPCMessage) (string, int, error) { +func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, []byte, error) { defer close(out) var body []byte - if len(in) == 1 { - data, err := jsonrpc2.EncodeMessage(in[0]) + if len(s.messages) == 1 { + data, err := jsonrpc2.EncodeMessage(s.messages[0]) if err != nil { - return "", 0, fmt.Errorf("encoding message: %w", err) + return "", 0, nil, fmt.Errorf("encoding message: %w", err) } body = data } else { var rawMsgs []json.RawMessage - for _, msg := range in { + for _, msg := range s.messages { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { - return "", 0, fmt.Errorf("encoding message: %w", err) + return "", 0, nil, fmt.Errorf("encoding message: %w", err) } rawMsgs = append(rawMsgs, data) } data, err := json.Marshal(rawMsgs) if err != nil { - return "", 0, fmt.Errorf("marshaling batch: %w", err) + return "", 0, nil, fmt.Errorf("marshaling batch: %w", err) } body = data } - req, err := http.NewRequestWithContext(ctx, method, serverURL, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, s.method, serverURL, bytes.NewReader(body)) if err != nil { - return "", 0, fmt.Errorf("creating request: %w", err) + return "", 0, nil, fmt.Errorf("creating request: %w", err) } if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(sessionIDHeader, sessionID) } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept", "text/plain") // ensure multiple accept headers are allowed - req.Header.Add("Accept", "application/json, text/event-stream") + req.Header.Set("Accept", "application/json, text/event-stream") + maps.Copy(req.Header, s.headers) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", 0, fmt.Errorf("request failed: %v", err) + return "", 0, nil, fmt.Errorf("request failed: %v", err) } defer resp.Body.Close() - newSessionID := resp.Header.Get("Mcp-Session-Id") + newSessionID := resp.Header.Get(sessionIDHeader) - if strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") { - for evt, err := range scanEvents(resp.Body) { + contentType := resp.Header.Get("Content-Type") + var respBody []byte + if strings.HasPrefix(contentType, "text/event-stream") { + r := readerInto{resp.Body, new(bytes.Buffer)} + for evt, err := range scanEvents(r) { if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("reading events: %v", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err) } // TODO(rfindley): do we need to check evt.name? // Does the MCP spec say anything about this? - msg, err := jsonrpc2.DecodeMessage(evt.data) + msg, err := jsonrpc2.DecodeMessage(evt.Data) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) } out <- msg } - } else if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + respBody = r.w.Bytes() + } else if strings.HasPrefix(contentType, "application/json") { data, err := io.ReadAll(resp.Body) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("reading json body: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading json body: %w", err) } + respBody = data msg, err := jsonrpc2.DecodeMessage(data) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) } out <- msg + } else { + respBody, err = io.ReadAll(resp.Body) + if err != nil { + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading response: %v", err) + } } - return newSessionID, resp.StatusCode, nil + return newSessionID, resp.StatusCode, respBody, nil +} + +// readerInto is an io.Reader that writes any bytes read from r into w. +type readerInto struct { + r io.Reader + w *bytes.Buffer } -func mustMarshal(t *testing.T, v any) json.RawMessage { +// Read implements io.Reader. +func (r readerInto) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil || err == io.EOF { + n2, err2 := r.w.Write(p[:n]) + if err2 != nil { + return n, fmt.Errorf("failed to write: %v", err) + } + if n2 != n { + return n, fmt.Errorf("short write: %d != %d", n2, n) + } + } + return n, err +} + +func mustMarshal(v any) json.RawMessage { if v == nil { return nil } - t.Helper() data, err := json.Marshal(v) if err != nil { - t.Fatal(err) + panic(err) } return data } func TestEventID(t *testing.T) { tests := []struct { - sid streamID + sid string idx int }{ - {0, 0}, - {0, 1}, - {1, 0}, - {1, 1}, - {1234, 5678}, + {"0", 0}, + {"0", 1}, + {"1", 0}, + {"1", 1}, + {"", 1}, + {"1234", 5678}, } for _, test := range tests { - t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) { + t.Run(fmt.Sprintf("%s_%d", test.sid, test.idx), func(t *testing.T) { eventID := formatEventID(test.sid, test.idx) gotSID, gotIdx, ok := parseEventID(eventID) if !ok { t.Fatalf("parseEventID(%q) failed, want ok", eventID) } if gotSID != test.sid || gotIdx != test.idx { - t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx) + t.Errorf("parseEventID(%q) = %s, %d, want %s, %d", eventID, gotSID, gotIdx, test.sid, test.idx) } }) } @@ -577,10 +1167,7 @@ func TestEventID(t *testing.T) { "", "_", "1_", - "_1", - "a_1", "1_a", - "-1_1", "1_-1", } @@ -592,3 +1179,254 @@ func TestEventID(t *testing.T) { }) } } + +func TestStreamableStateless(t *testing.T) { + initReq := req(1, methodInitialize, &InitializeParams{}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "test", Version: "v1.0.0"}, + }, nil) + // This version of sayHi expects + // that request from our client). + sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { + if err := req.Session.Ping(ctx, nil); err == nil { + // ping should fail, but not break the connection + t.Errorf("ping succeeded unexpectedly") + } + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil + } + + requests := []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: false, // sessionless + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, + wantBodyContaining: "greet", + }, + { + method: "GET", + wantStatusCode: http.StatusMethodNotAllowed, + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "World"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{ + Content: []Content{&TextContent{Text: "hi World"}}, + }, nil), + }, + }, + { + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{ + req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}), + }, + wantMessages: []jsonrpc.Message{ + resp(2, &CallToolResult{ + Content: []Content{&TextContent{Text: "hi foo"}}, + }, nil), + }, + }, + } + + testClientCompatibility := func(t *testing.T, handler http.Handler) { + ctx := context.Background() + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil) + if err != nil { + t.Fatal(err) + } + res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}}) + if err != nil { + t.Fatal(err) + } + if got, want := textContent(t, res), "hi bar"; got != want { + t.Errorf("Result = %q, want %q", got, want) + } + } + + sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { + // Return a stateless server which never assigns a session ID. + server := NewServer(testImpl, &ServerOptions{ + GetSessionID: func() string { return "" }, + }) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + return server + }, &StreamableHTTPOptions{ + Stateless: true, + }) + + // First, test the "sessionless" stateless mode, where there is no session ID. + t.Run("sessionless", func(t *testing.T) { + testStreamableHandler(t, sessionlessHandler, requests) + testClientCompatibility(t, sessionlessHandler) + }) + + // Next, test the default stateless mode, where session IDs are permitted. + // + // This can be used by tools to look up application state preserved across + // subsequent requests. + requests[0].wantSessionID = true // now expect a session ID for initialize + statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { + // Return a server with default options which should assign a random session ID. + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + return server + }, &StreamableHTTPOptions{ + Stateless: true, + }) + t.Run("stateless", func(t *testing.T) { + testStreamableHandler(t, statelessHandler, requests) + testClientCompatibility(t, sessionlessHandler) + }) +} + +func textContent(t *testing.T, res *CallToolResult) string { + t.Helper() + if len(res.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(res.Content)) + } + text, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatalf("Content[0] is %T, want *TextContent", res.Content[0]) + } + return text.Text +} + +func TestTokenInfo(t *testing.T) { + defer func(b bool) { testAuth = b }(testAuth) + testAuth = true + ctx := context.Background() + + // Create a server with a tool that returns TokenInfo. + tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) + + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) { + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + // Expiration is far, far in the future. + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) + if err != nil { + t.Fatal(err) + } + if len(res.Content) == 0 { + t.Fatal("missing content") + } + tc, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatal("not TextContent") + } + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} + +func TestStreamableGET(t *testing.T) { + // This test checks the fix for problematic behavior described in #410: + // Hanging GET headers should be written immediately, even if there are no + // messages. + server := NewServer(testImpl, nil) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + newReq := func(method string, msg jsonrpc.Message) *http.Request { + var body io.Reader + if msg != nil { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatal(err) + } + body = bytes.NewReader(data) + } + req, err := http.NewRequestWithContext(ctx, method, httpServer.URL, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "application/json, text/event-stream") + if msg != nil { + req.Header.Set("Content-Type", "application/json") + } + return req + } + + get1 := newReq(http.MethodGet, nil) + resp, err := http.DefaultClient.Do(get1) + if err != nil { + t.Fatal(err) + } + if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { + t.Errorf("initial GET: got status %d, want %d", got, want) + } + defer resp.Body.Close() + + post1 := newReq(http.MethodPost, req(1, methodInitialize, &InitializeParams{})) + resp, err = http.DefaultClient.Do(post1) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Errorf("initialize POST: got status %d, want %d; body:\n%s", got, want, string(body)) + } + + sessionID := resp.Header.Get(sessionIDHeader) + if sessionID == "" { + t.Fatalf("initialized missing session ID") + } + + get2 := newReq("GET", nil) + get2.Header.Set(sessionIDHeader, sessionID) + resp, err = http.DefaultClient.Do(get2) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("GET with session ID: got status %d, want %d", got, want) + } +} diff --git a/mcp/testdata/conformance/server/bad_requests.txtar b/mcp/testdata/conformance/server/bad_requests.txtar new file mode 100644 index 00000000..44816189 --- /dev/null +++ b/mcp/testdata/conformance/server/bad_requests.txtar @@ -0,0 +1,105 @@ +Check robustness to missing fields: servers should reject and otherwise ignore +bad requests. + +Fixed bugs: +- No id in 'initialize' should not panic (#197). +- No id in 'ping' should not panic (#194). +- No params in 'initialize' should not panic (#195). +- Notifications with IDs should not be treated like requests. (#196) +- No params in 'logging/setLevel' should not panic. +- No params in 'completion/complete' should not panic. +- JSON null params should also not panic in these cases. + +-- prompts -- +code_review + +-- client -- +{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" +} +{ + "jsonrpc": "2.0", + "id": 2, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc":"2.0", "id": 3, "method": "notifications/initialized" } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc":"2.0", "method":"ping" } +{ "jsonrpc":"2.0", "id": 4, "method": "logging/setLevel" } +{ "jsonrpc":"2.0", "id": 5, "method": "completion/complete" } +{ "jsonrpc":"2.0", "id": 4, "method": "logging/setLevel", "params": null } + +-- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "capabilities": { + "logging": {}, + "prompts": { + "listChanged": true + } + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 3, + "error": { + "code": -32600, + "message": "invalid request: unexpected id for \"notifications/initialized\"" + } +} +{ + "jsonrpc": "2.0", + "id": 4, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 4, + "error": { + "code": -32600, + "message": "handling 'logging/setLevel': invalid request: missing required \"params\"" + } +} diff --git a/mcp/testdata/conformance/server/lifecycle.txtar b/mcp/testdata/conformance/server/lifecycle.txtar new file mode 100644 index 00000000..0a8cf34b --- /dev/null +++ b/mcp/testdata/conformance/server/lifecycle.txtar @@ -0,0 +1,65 @@ +This test checks that the server obeys the rules for initialization lifecycle, +and rejects non-ping requests until 'initialized' is received. + +See also modelcontextprotocol/go-sdk#225. + +-- client -- +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc":"2.0", "id": 1, "method":"ping" } +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc": "2.0", "id": 3, "method": "tools/list" } + +-- server -- +{ + "jsonrpc": "2.0", + "id": 2, + "error": { + "code": 0, + "message": "method \"tools/list\" is invalid during session initialization" + } +} +{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "capabilities": { + "logging": {} + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 1, + "result": {} +} +{ + "jsonrpc": "2.0", + "id": 2, + "result": { + "tools": [] + } +} +{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "tools": [] + } +} diff --git a/mcp/testdata/conformance/server/prompts.txtar b/mcp/testdata/conformance/server/prompts.txtar index 6168ce8e..fdaf7932 100644 --- a/mcp/testdata/conformance/server/prompts.txtar +++ b/mcp/testdata/conformance/server/prompts.txtar @@ -1,7 +1,8 @@ Check behavior of a server with just prompts. Fixed bugs: -- empty tools lists should not be returned as 'null' +- Empty tools lists should not be returned as 'null'. +- No params in 'prompts/get' should not panic. -- prompts -- code_review @@ -17,24 +18,20 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } +{ "jsonrpc": "2.0", "id": 5, "method": "prompts/get" } + -- server -- { "jsonrpc": "2.0", "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {}, "prompts": { "listChanged": true - }, - "resources": { - "listChanged": true - }, - "tools": { - "listChanged": true } }, "protocolVersion": "2024-11-05", @@ -69,3 +66,11 @@ code_review ] } } +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} diff --git a/mcp/testdata/conformance/server/resources.txtar b/mcp/testdata/conformance/server/resources.txtar index 3e7031ad..314817b8 100644 --- a/mcp/testdata/conformance/server/resources.txtar +++ b/mcp/testdata/conformance/server/resources.txtar @@ -2,6 +2,9 @@ Check behavior of a server with just resources. Fixed bugs: - A resource result holds a slice of contents, not just one. +- No params in 'resource/read' should not panic. +- No params in 'resources/subscribe' should not panic. +- No params in 'resources/unsubscribe' should not panic. -- resources -- info @@ -18,6 +21,7 @@ info.txt "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "resources/list" } { "jsonrpc": "2.0", "id": 3, @@ -39,22 +43,17 @@ info.txt "roots": [] } } +{ "jsonrpc": "2.0", "id": 4, "method": "resources/read" } +{ "jsonrpc": "2.0", "id": 5, "method": "resources/subscribe" } -- server -- { "jsonrpc": "2.0", "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {}, - "prompts": { - "listChanged": true - }, "resources": { "listChanged": true - }, - "tools": { - "listChanged": true } }, "protocolVersion": "2024-11-05", @@ -113,3 +112,19 @@ info.txt ] } } +{ + "jsonrpc": "2.0", + "id": 4, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index a43cd075..b582dda8 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -4,9 +4,18 @@ Fixed bugs: - "tools/list" can have missing params - "_meta" should not be nil - empty resource or prompts should not be returned as 'null' +- the server should not crash when params are passed to tools/call +- missing required input fields should be rejected (#449) +- output and input should be validated against their actual json, not Go + representation +- When arguments are missing, the request should succeed if all properties are + optional, and observe any default values. -- tools -- greet +structured +tomorrow +inc -- client -- { @@ -19,23 +28,27 @@ greet "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } +{ "jsonrpc": "2.0", "id": 5, "method": "tools/call" } +{ "jsonrpc": "2.0", "id": 6, "method": "tools/call", "params": {"name": "greet", "arguments": {"Name": "you"} } } +{ "jsonrpc": "2.0", "id": 1, "result": {} } +{ "jsonrpc": "2.0", "id": 7, "method": "tools/call", "params": {"name": "structured", "arguments": {"In": "input"} } } +{ "jsonrpc": "2.0", "id": 8, "method": "tools/call", "params": {"name": "structured", "arguments": {} } } +{ "jsonrpc": "2.0", "id": 9, "method": "tools/call", "params": {"name": "tomorrow", "arguments": { "Now": "2025-06-18T15:04:05Z" } } } +{ "jsonrpc": "2.0", "id": 10, "method": "tools/call", "params": {"name": "greet" } } +{ "jsonrpc": "2.0", "id": 11, "method": "tools/call", "params": {"name": "inc", "arguments": { "x": 3 } } } +{ "jsonrpc": "2.0", "id": 11, "method": "tools/call", "params": {"name": "inc" } } + -- server -- { "jsonrpc": "2.0", "id": 1, "result": { "capabilities": { - "completions": {}, "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "listChanged": true - }, "tools": { "listChanged": true } @@ -64,11 +77,90 @@ greet "type": "string" } }, - "additionalProperties": { - "not": {} - } + "additionalProperties": false }, "name": "greet" + }, + { + "inputSchema": { + "type": "object", + "properties": { + "x": { + "type": "integer", + "default": 6 + } + }, + "additionalProperties": false + }, + "name": "inc", + "outputSchema": { + "type": "object", + "required": [ + "y" + ], + "properties": { + "y": { + "type": "integer" + } + }, + "additionalProperties": false + } + }, + { + "inputSchema": { + "type": "object", + "required": [ + "In" + ], + "properties": { + "In": { + "type": "string", + "description": "the input" + } + }, + "additionalProperties": false + }, + "name": "structured", + "outputSchema": { + "type": "object", + "required": [ + "Out" + ], + "properties": { + "Out": { + "type": "string", + "description": "the output" + } + }, + "additionalProperties": false + } + }, + { + "inputSchema": { + "type": "object", + "required": [ + "Now" + ], + "properties": { + "Now": { + "type": "string" + } + }, + "additionalProperties": false + }, + "name": "tomorrow", + "outputSchema": { + "type": "object", + "required": [ + "Tomorrow" + ], + "properties": { + "Tomorrow": { + "type": "string" + } + }, + "additionalProperties": false + } } ] } @@ -87,3 +179,104 @@ greet "prompts": [] } } +{ + "jsonrpc": "2.0", + "id": 5, + "error": { + "code": -32600, + "message": "invalid request: missing required \"params\"" + } +} +{ + "jsonrpc": "2.0", + "id": 1, + "method": "ping" +} +{ + "jsonrpc": "2.0", + "id": 6, + "result": { + "content": [ + { + "type": "text", + "text": "hi you" + } + ] + } +} +{ + "jsonrpc": "2.0", + "id": 7, + "result": { + "content": [ + { + "type": "text", + "text": "{\"Out\":\"Ack input\"}" + } + ], + "structuredContent": { + "Out": "Ack input" + } + } +} +{ + "jsonrpc": "2.0", + "id": 8, + "error": { + "code": -32602, + "message": "invalid params: validating \"arguments\": validating root: required: missing properties: [\"In\"]" + } +} +{ + "jsonrpc": "2.0", + "id": 9, + "result": { + "content": [ + { + "type": "text", + "text": "{\"Tomorrow\":\"2025-06-19T15:04:05Z\"}" + } + ], + "structuredContent": { + "Tomorrow": "2025-06-19T15:04:05Z" + } + } +} +{ + "jsonrpc": "2.0", + "id": 10, + "error": { + "code": -32602, + "message": "invalid params: validating \"arguments\": validating root: required: missing properties: [\"Name\"]" + } +} +{ + "jsonrpc": "2.0", + "id": 11, + "result": { + "content": [ + { + "type": "text", + "text": "{\"y\":4}" + } + ], + "structuredContent": { + "y": 4 + } + } +} +{ + "jsonrpc": "2.0", + "id": 11, + "result": { + "content": [ + { + "type": "text", + "text": "{\"y\":7}" + } + ], + "structuredContent": { + "y": 7 + } + } +} diff --git a/mcp/testdata/conformance/server/version-latest.txtar b/mcp/testdata/conformance/server/version-latest.txtar index 760bf8b7..75317676 100644 --- a/mcp/testdata/conformance/server/version-latest.txtar +++ b/mcp/testdata/conformance/server/version-latest.txtar @@ -18,19 +18,9 @@ response with its latest supported version. "id": 1, "result": { "capabilities": { - "completions": {}, - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "listChanged": true - }, - "tools": { - "listChanged": true - } + "logging": {} }, - "protocolVersion": "2025-03-26", + "protocolVersion": "2025-06-18", "serverInfo": { "name": "testServer", "version": "v1.0.0" diff --git a/mcp/testdata/conformance/server/version-older.txtar b/mcp/testdata/conformance/server/version-older.txtar index 97f7b79b..82292630 100644 --- a/mcp/testdata/conformance/server/version-older.txtar +++ b/mcp/testdata/conformance/server/version-older.txtar @@ -18,17 +18,7 @@ support. "id": 1, "result": { "capabilities": { - "completions": {}, - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "listChanged": true - }, - "tools": { - "listChanged": true - } + "logging": {} }, "protocolVersion": "2024-11-05", "serverInfo": { diff --git a/mcp/tool.go b/mcp/tool.go index a6f228eb..12b02b7b 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -5,274 +5,99 @@ package mcp import ( - "bytes" "context" "encoding/json" "fmt" - "slices" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" ) // A ToolHandler handles a call to tools/call. -// [CallToolParams.Arguments] will contain a map[string]any that has been validated -// against the input schema. -// TODO: Perhaps this should be an alias for ToolHandlerFor[map[string]any, map[string]any]? -type ToolHandler func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) +// +// This is a low-level API, for use with [Server.AddTool]. It does not do any +// pre- or post-processing of the request or result: the params contain raw +// arguments, no input validation is performed, and the result is returned to +// the user as-is, without any validation of the output. +// +// Most users will write a [ToolHandlerFor] and install it with the generic +// [AddTool] function. +// +// If ToolHandler returns an error, it is treated as a protocol error. By +// contrast, [ToolHandlerFor] automatically populates [CallToolResult.IsError] +// and [CallToolResult.Content] accordingly. +type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) // A ToolHandlerFor handles a call to tools/call with typed arguments and results. -type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) - -// A rawToolHandler is like a ToolHandler, but takes the arguments as as json.RawMessage. -type rawToolHandler = func(context.Context, *ServerSession, *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) - -// A ServerTool is a tool definition that is bound to a tool handler. -type ServerTool struct { - Tool *Tool - Handler ToolHandler - // Set in NewServerTool or Server.addToolsErr. - rawHandler rawToolHandler - // Resolved tool schemas. Set in Server.addToolsErr. - inputResolved, outputResolved *jsonschema.Resolved -} - -// NewServerTool is a helper to make a tool using reflection on the given type parameters. -// When the tool is called, CallToolParams.Arguments will be of type In. // -// If provided, variadic [ToolOption] values may be used to customize the tool. +// Use [AddTool] to add a ToolHandlerFor to a server. // -// The input schema for the tool is extracted from the request type for the -// handler, and used to unmmarshal and validate requests to the handler. This -// schema may be customized using the [Input] option. +// Unlike [ToolHandler], [ToolHandlerFor] provides significant functionality +// out of the box, and enforces that the tool conforms to the MCP spec: +// - The In type provides a default input schema for the tool, though it may +// be overridden in [AddTool]. +// - The input value is automatically unmarshaled from req.Params.Arguments. +// - The input value is automatically validated against its input schema. +// Invalid input is rejected before getting to the handler. +// - If the Out type is not the empty interface [any], it provides the +// default output schema for the tool (which again may be overridden in +// [AddTool]). +// - The Out value is used to populate result.StructuredOutput. +// - If [CallToolResult.Content] is unset, it is populated with the JSON +// content of the output. +// - An error result is treated as a tool error, rather than a protocol +// error, and is therefore packed into CallToolResult.Content, with +// [IsError] set. // -// TODO(jba): check that structured content is set in response. -func NewServerTool[In, Out any](name, description string, handler ToolHandlerFor[In, Out], opts ...ToolOption) *ServerTool { - st, err := newServerToolErr[In, Out](name, description, handler, opts...) - if err != nil { - panic(fmt.Errorf("NewServerTool(%q): %w", name, err)) - } - return st -} - -func newServerToolErr[In, Out any](name, description string, handler ToolHandlerFor[In, Out], opts ...ToolOption) (*ServerTool, error) { - // TODO: check that In is a struct. - ischema, err := jsonschema.For[In]() - if err != nil { - return nil, err - } - // TODO: uncomment when output schemas drop. - // oschema, err := jsonschema.For[TRes]() - // if err != nil { - // return nil, err - // } - - t := &ServerTool{ - Tool: &Tool{ - Name: name, - Description: description, - InputSchema: ischema, - // OutputSchema: oschema, - }, - } - for _, opt := range opts { - opt.set(t) - } - - t.rawHandler = func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { - var args In - if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, t.inputResolved, &args); err != nil { - return nil, err - } - } - // TODO(jba): future-proof this copy. - params := &CallToolParamsFor[In]{ - Meta: rparams.Meta, - Name: rparams.Name, - Arguments: args, - } - res, err := handler(ctx, ss, params) - if err != nil { - return nil, err - } +// For these reasons, most users can ignore the [CallToolRequest] argument and +// [CallToolResult] return values entirely. In fact, it is permissible to +// return a nil CallToolResult, if you only care about returning a output value +// or error. The effective result will be populated as described above. +type ToolHandlerFor[In, Out any] func(_ context.Context, request *CallToolRequest, input In) (result *CallToolResult, output Out, _ error) - var ctr CallToolResult - if res != nil { - // TODO(jba): future-proof this copy. - ctr.Meta = res.Meta - ctr.Content = res.Content - ctr.IsError = res.IsError - } - return &ctr, nil - } - return t, nil +// A serverTool is a tool definition that is bound to a tool handler. +type serverTool struct { + tool *Tool + handler ToolHandler } -// newRawHandler creates a rawToolHandler for tools not created through NewServerTool. -// It unmarshals the arguments into a map[string]any and validates them against the -// schema, then calls the ServerTool's handler. -func newRawHandler(st *ServerTool) rawToolHandler { - if st.Handler == nil { - panic("st.Handler is nil") - } - return func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { - // Unmarshal the args into what should be a map. - var args map[string]any - if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, st.inputResolved, &args); err != nil { - return nil, err - } - } - // TODO: generate copy - params := &CallToolParamsFor[map[string]any]{ - Meta: rparams.Meta, - Name: rparams.Name, - Arguments: args, - } - res, err := st.Handler(ctx, ss, params) - // TODO(rfindley): investigate why server errors are embedded in this strange way, - // rather than returned as jsonrpc2 server errors. - if err != nil { - return &CallToolResult{ - Content: []Content{&TextContent{Text: err.Error()}}, - IsError: true, - }, nil - } - return res, nil - } -} - -// unmarshalSchema unmarshals data into v and validates the result according to -// the given resolved schema. -func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) error { +// applySchema validates whether data is valid JSON according to the provided +// schema, after applying schema defaults. +// +// Returns the JSON value augmented with defaults. +func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) { // TODO: use reflection to create the struct type to unmarshal into. // Separate validation from assignment. - // Disallow unknown fields. - // Otherwise, if the tool was built with a struct, the client could send extra - // fields and json.Unmarshal would ignore them, so the schema would never get - // a chance to declare the extra args invalid. - dec := json.NewDecoder(bytes.NewReader(data)) - dec.DisallowUnknownFields() - if err := dec.Decode(v); err != nil { - return fmt.Errorf("unmarshaling: %w", err) - } - // TODO: test with nil args. + // Use default JSON marshalling for validation. + // + // This avoids inconsistent representation due to custom marshallers, such as + // time.Time (issue #449). + // + // Additionally, unmarshalling into a map ensures that the resulting JSON is + // at least {}, even if data is empty. For example, arguments is technically + // an optional property of callToolParams, and we still want to apply the + // defaults in this case. + // + // TODO(rfindley): in which cases can resolved be nil? if resolved != nil { - if err := resolved.ApplyDefaults(v); err != nil { - return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%s:\n%w", schemaJSON(resolved.Schema()), data, err) - } - if err := resolved.Validate(v); err != nil { - return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err) + v := make(map[string]any) + if len(data) > 0 { + if err := json.Unmarshal(data, &v); err != nil { + return nil, fmt.Errorf("unmarshaling arguments: %w", err) + } } - } - return nil -} - -// A ToolOption configures the behavior of a Tool. -type ToolOption interface { - set(*ServerTool) -} - -type toolSetter func(*ServerTool) - -func (s toolSetter) set(t *ServerTool) { s(t) } - -// Input applies the provided [SchemaOption] configuration to the tool's input -// schema. -func Input(opts ...SchemaOption) ToolOption { - return toolSetter(func(t *ServerTool) { - for _, opt := range opts { - opt.set(t.Tool.InputSchema) + if err := resolved.ApplyDefaults(&v); err != nil { + return nil, fmt.Errorf("applying schema defaults:\n%w", err) } - }) -} - -// A SchemaOption configures a jsonschema.Schema. -type SchemaOption interface { - set(s *jsonschema.Schema) -} - -type schemaSetter func(*jsonschema.Schema) - -func (s schemaSetter) set(schema *jsonschema.Schema) { s(schema) } - -// Property configures the schema for the property of the given name. -// If there is no such property in the schema, it is created. -func Property(name string, opts ...SchemaOption) SchemaOption { - return schemaSetter(func(schema *jsonschema.Schema) { - propSchema, ok := schema.Properties[name] - if !ok { - propSchema = new(jsonschema.Schema) - schema.Properties[name] = propSchema + if err := resolved.Validate(&v); err != nil { + return nil, err } - // Apply the options, with special handling for Required, as it needs to be - // set on the parent schema. - for _, opt := range opts { - if req, ok := opt.(required); ok { - if req { - if !slices.Contains(schema.Required, name) { - schema.Required = append(schema.Required, name) - } - } else { - schema.Required = slices.DeleteFunc(schema.Required, func(s string) bool { - return s == name - }) - } - } else { - opt.set(propSchema) - } + // We must re-marshal with the default values applied. + var err error + data, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("marshalling with defaults: %v", err) } - }) -} - -// Required sets whether the associated property is required. It is only valid -// when used in a [Property] option: using Required outside of Property panics. -func Required(v bool) SchemaOption { - return required(v) -} - -// required must be a distinguished type as it needs special handling to mutate -// the parent schema, and to mutate prompt arguments. -type required bool - -func (required) set(s *jsonschema.Schema) { - panic("use of required outside of Property") -} - -// Enum sets the provided values as the "enum" value of the schema. -func Enum(values ...any) SchemaOption { - return schemaSetter(func(s *jsonschema.Schema) { - s.Enum = values - }) -} - -// Description sets the provided schema description. -func Description(desc string) SchemaOption { - return description(desc) -} - -// description must be a distinguished type so that it can be handled by prompt -// options. -type description string - -func (d description) set(s *jsonschema.Schema) { - s.Description = string(d) -} - -// Schema overrides the inferred schema with a shallow copy of the given -// schema. -func Schema(schema *jsonschema.Schema) SchemaOption { - return schemaSetter(func(s *jsonschema.Schema) { - *s = *schema - }) -} - -// schemaJSON returns the JSON value for s as a string, or a string indicating an error. -func schemaJSON(s *jsonschema.Schema) string { - m, err := json.Marshal(s) - if err != nil { - return fmt.Sprintf("", err) } - return string(m) + return data, nil } diff --git a/mcp/tool_example_test.go b/mcp/tool_example_test.go new file mode 100644 index 00000000..8f3fbbe6 --- /dev/null +++ b/mcp/tool_example_test.go @@ -0,0 +1,219 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func ExampleAddTool_customMarshalling() { + // Sometimes when you want to customize the input or output schema for a + // tool, you need to customize the schema of a single helper type that's used + // in several places. + // + // For example, suppose you had a type that marshals/unmarshals like a + // time.Time, and that type was used multiple times in your tool input. + type MyDate struct { + time.Time + } + type Input struct { + Query string `json:"query,omitempty"` + Start MyDate `json:"start,omitempty"` + End MyDate `json:"end,omitempty"` + } + + // In this case, you can use jsonschema.For along with jsonschema.ForOptions + // to customize the schema inference for your custom type. + inputSchema, err := jsonschema.For[Input](&jsonschema.ForOptions{ + TypeSchemas: map[any]*jsonschema.Schema{ + MyDate{}: {Type: "string"}, + }, + }) + if err != nil { + log.Fatal(err) + } + + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + toolHandler := func(context.Context, *mcp.CallToolRequest, Input) (*mcp.CallToolResult, any, error) { + panic("not implemented") + } + mcp.AddTool(server, &mcp.Tool{Name: "my_tool", InputSchema: inputSchema}, toolHandler) + + ctx := context.Background() + session, err := connect(ctx, server) // create an in-memory connection + if err != nil { + log.Fatal(err) + } + defer session.Close() + + for t, err := range session.Tools(ctx, nil) { + if err != nil { + log.Fatal(err) + } + schemaJSON, err := json.MarshalIndent(t.InputSchema, "", "\t") + if err != nil { + log.Fatal(err) + } + fmt.Println(t.Name, string(schemaJSON)) + } + // Output: + // my_tool { + // "type": "object", + // "properties": { + // "end": { + // "type": "string" + // }, + // "query": { + // "type": "string" + // }, + // "start": { + // "type": "string" + // } + // }, + // "additionalProperties": false + // } +} + +type Location struct { + Name string `json:"name"` + Latitude *float64 `json:"latitude,omitempty"` + Longitude *float64 `json:"longitude,omitempty"` +} + +type Forecast struct { + Forecast string `json:"forecast" jsonschema:"description of the day's weather"` + Type WeatherType `json:"type" jsonschema:"type of weather"` + Rain float64 `json:"rain" jsonschema:"probability of rain, between 0 and 1"` + High float64 `json:"high" jsonschema:"high temperature"` + Low float64 `json:"low" jsonschema:"low temperature"` +} + +type WeatherType string + +const ( + Sunny WeatherType = "sun" + PartlyCloudy WeatherType = "partly_cloudy" + Cloudy WeatherType = "clouds" + Rainy WeatherType = "rain" + Snowy WeatherType = "snow" +) + +type Probability float64 + +// !+weathertool + +type WeatherInput struct { + Location Location `json:"location" jsonschema:"user location"` + Days int `json:"days" jsonschema:"number of days to forecast"` +} + +type WeatherOutput struct { + Summary string `json:"summary" jsonschema:"a summary of the weather forecast"` + Confidence Probability `json:"confidence" jsonschema:"confidence, between 0 and 1"` + AsOf time.Time `json:"asOf" jsonschema:"the time the weather was computed"` + DailyForecast []Forecast `json:"dailyForecast" jsonschema:"the daily forecast"` + Source string `json:"source,omitempty" jsonschema:"the organization providing the weather forecast"` +} + +func WeatherTool(ctx context.Context, req *mcp.CallToolRequest, in WeatherInput) (*mcp.CallToolResult, WeatherOutput, error) { + perfectWeather := WeatherOutput{ + Summary: "perfect", + Confidence: 1.0, + AsOf: time.Now(), + } + for range in.Days { + perfectWeather.DailyForecast = append(perfectWeather.DailyForecast, Forecast{ + Forecast: "another perfect day", + Type: Sunny, + Rain: 0.0, + High: 72.0, + Low: 72.0, + }) + } + return nil, perfectWeather, nil +} + +// !-weathertool + +func ExampleAddTool_complexSchema() { + // This example demonstrates a tool with a more 'realistic' input and output + // schema. We use a combination of techniques to tune our input and output + // schemas. + + // !+customschemas + + // Distinguished Go types allow custom schemas to be reused during inference. + customSchemas := map[any]*jsonschema.Schema{ + Probability(0): {Type: "number", Minimum: jsonschema.Ptr(0.0), Maximum: jsonschema.Ptr(1.0)}, + WeatherType(""): {Type: "string", Enum: []any{Sunny, PartlyCloudy, Cloudy, Rainy, Snowy}}, + } + opts := &jsonschema.ForOptions{TypeSchemas: customSchemas} + in, err := jsonschema.For[WeatherInput](opts) + if err != nil { + log.Fatal(err) + } + + // Furthermore, we can tweak the inferred schema, in this case limiting + // forecasts to 0-10 days. + daysSchema := in.Properties["days"] + daysSchema.Minimum = jsonschema.Ptr(0.0) + daysSchema.Maximum = jsonschema.Ptr(10.0) + + // Output schema inference can reuse our custom schemas from input inference. + out, err := jsonschema.For[WeatherOutput](opts) + if err != nil { + log.Fatal(err) + } + + // Now add our tool to a server. Since we've customized the schemas, we need + // to override the default schema inference. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{ + Name: "weather", + InputSchema: in, + OutputSchema: out, + }, WeatherTool) + + // !-customschemas + + ctx := context.Background() + session, err := connect(ctx, server) // create an in-memory connection + if err != nil { + log.Fatal(err) + } + defer session.Close() + + // Check that the client observes the correct schemas. + for t, err := range session.Tools(ctx, nil) { + if err != nil { + log.Fatal(err) + } + // Formatting the entire schemas would be too much output. + // Just check that our customizations were effective. + fmt.Println("max days:", *t.InputSchema.Properties["days"].Maximum) + fmt.Println("max confidence:", *t.OutputSchema.Properties["confidence"].Maximum) + fmt.Println("weather types:", t.OutputSchema.Properties["dailyForecast"].Items.Properties["type"].Enum) + } + // Output: + // max days: 10 + // max confidence: 1 + // weather types: [sun partly_cloudy clouds rain snow] +} + +func connect(ctx context.Context, server *mcp.Server) (*mcp.ClientSession, error) { + t1, t2 := mcp.NewInMemoryTransports() + if _, err := server.Connect(ctx, t1, nil); err != nil { + return nil, err + } + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + return client.Connect(ctx, t2, nil) +} diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 85775e9b..ef26e9dc 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -7,90 +7,17 @@ package mcp import ( "context" "encoding/json" + "errors" + "fmt" "reflect" + "strings" "testing" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/modelcontextprotocol/go-sdk/jsonschema" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) -// testToolHandler is used for type inference in TestNewServerTool. -func testToolHandler[T any](context.Context, *ServerSession, *CallToolParamsFor[T]) (*CallToolResultFor[any], error) { - panic("not implemented") -} - -func TestNewServerTool(t *testing.T) { - tests := []struct { - tool *ServerTool - want *jsonschema.Schema - }{ - { - NewServerTool("basic", "", testToolHandler[struct { - Name string `json:"name"` - }]), - &jsonschema.Schema{ - Type: "object", - Required: []string{"name"}, - Properties: map[string]*jsonschema.Schema{ - "name": {Type: "string"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, - }, - { - NewServerTool("enum", "", testToolHandler[struct{ Name string }], Input( - Property("Name", Enum("x", "y", "z")), - )), - &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string", Enum: []any{"x", "y", "z"}}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, - }, - { - NewServerTool("required", "", testToolHandler[struct { - Name string `json:"name"` - Language string `json:"language"` - X int `json:"x,omitempty"` - Y int `json:"y,omitempty"` - }], Input( - Property("x", Required(true)))), - &jsonschema.Schema{ - Type: "object", - Required: []string{"name", "language", "x"}, - Properties: map[string]*jsonschema.Schema{ - "language": {Type: "string"}, - "name": {Type: "string"}, - "x": {Type: "integer"}, - "y": {Type: "integer"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, - }, - { - NewServerTool("set_schema", "", testToolHandler[struct { - X int `json:"x,omitempty"` - Y int `json:"y,omitempty"` - }], Input( - Schema(&jsonschema.Schema{Type: "object"})), - ), - &jsonschema.Schema{ - Type: "object", - }, - }, - } - for _, test := range tests { - if diff := cmp.Diff(test.want, test.tool.Tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("NewServerTool(%v) mismatch (-want +got):\n%s", test.tool.Tool.Name, diff) - } - } -} - -func TestUnmarshalSchema(t *testing.T) { +func TestApplySchema(t *testing.T) { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -112,19 +39,109 @@ func TestUnmarshalSchema(t *testing.T) { want any }{ {`{"x": 1}`, new(S), &S{X: 1}}, - {`{}`, new(S), &S{X: 3}}, // default applied - {`{"x": 0}`, new(S), &S{X: 3}}, // FAIL: should be 0. (requires double unmarshal) + {`{}`, new(S), &S{X: 3}}, // default applied + {`{"x": 0}`, new(S), &S{X: 0}}, {`{"x": 1}`, new(map[string]any), &map[string]any{"x": 1.0}}, {`{}`, new(map[string]any), &map[string]any{"x": 3.0}}, // default applied {`{"x": 0}`, new(map[string]any), &map[string]any{"x": 0.0}}, } { raw := json.RawMessage(tt.data) - if err := unmarshalSchema(raw, resolved, tt.v); err != nil { + raw, err = applySchema(raw, resolved) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(raw, &tt.v); err != nil { t.Fatal(err) } if !reflect.DeepEqual(tt.v, tt.want) { t.Errorf("got %#v, want %#v", tt.v, tt.want) } + } +} + +func TestToolErrorHandling(t *testing.T) { + // Construct server and add both tools at the top level + server := NewServer(testImpl, nil) + + // Create a tool that returns a structured error + structuredErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { + return nil, nil, &jsonrpc2.WireError{ + Code: CodeInvalidParams, + Message: "internal server error", + } + } + + // Create a tool that returns a regular error + regularErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { + return nil, nil, fmt.Errorf("tool execution failed") + } + AddTool(server, &Tool{Name: "error_tool", Description: "returns structured error"}, structuredErrorHandler) + AddTool(server, &Tool{Name: "regular_error_tool", Description: "returns regular error"}, regularErrorHandler) + + // Connect server and client once + ct, st := NewInMemoryTransports() + _, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatal(err) } + + client := NewClient(testImpl, nil) + cs, err := client.Connect(context.Background(), ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test that structured JSON-RPC errors are returned directly + t.Run("structured_error", func(t *testing.T) { + // Call the tool + _, err = cs.CallTool(context.Background(), &CallToolParams{ + Name: "error_tool", + Arguments: map[string]any{}, + }) + + // Should get the structured error directly + if err == nil { + t.Fatal("expected error, got nil") + } + + var wireErr *jsonrpc2.WireError + if !errors.As(err, &wireErr) { + t.Fatalf("expected WireError, got %[1]T: %[1]v", err) + } + + if wireErr.Code != CodeInvalidParams { + t.Errorf("expected error code %d, got %d", CodeInvalidParams, wireErr.Code) + } + }) + + // Test that regular errors are embedded in tool results + t.Run("regular_error", func(t *testing.T) { + // Call the tool + result, err := cs.CallTool(context.Background(), &CallToolParams{ + Name: "regular_error_tool", + Arguments: map[string]any{}, + }) + // Should not get an error at the protocol level + if err != nil { + t.Fatalf("unexpected protocol error: %v", err) + } + + // Should get a result with IsError=true + if !result.IsError { + t.Error("expected IsError=true, got false") + } + + // Should have error message in content + if len(result.Content) == 0 { + t.Error("expected error content, got empty") + } + + if textContent, ok := result.Content[0].(*TextContent); !ok { + t.Error("expected TextContent") + } else if !strings.Contains(textContent.Text, "tool execution failed") { + t.Errorf("expected error message in content, got: %s", textContent.Text) + } + }) } diff --git a/mcp/transport.go b/mcp/transport.go index 85bfaf65..d2109e7d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -10,12 +10,14 @@ import ( "errors" "fmt" "io" + "log" "net" "os" "sync" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) // ErrConnectionClosed is returned when sending a message to a connection that @@ -34,99 +36,120 @@ type Transport interface { Connect(ctx context.Context) (Connection, error) } -type ( - // JSONRPCID is a JSON-RPC request ID. - JSONRPCID = jsonrpc2.ID - // JSONRPCMessage is a JSON-RPC message. - JSONRPCMessage = jsonrpc2.Message - // JSONRPCRequest is a JSON-RPC request. - JSONRPCRequest = jsonrpc2.Request - // JSONRPCResponse is a JSON-RPC response. - JSONRPCResponse = jsonrpc2.Response -) - // A Connection is a logical bidirectional JSON-RPC connection. type Connection interface { - Read(context.Context) (JSONRPCMessage, error) - Write(context.Context, JSONRPCMessage) error - Close() error // may be called concurrently by both peers + // Read reads the next message to process off the connection. + // + // Connections must allow Read to be called concurrently with Close. In + // particular, calling Close should unblock a Read waiting for input. + Read(context.Context) (jsonrpc.Message, error) + + // Write writes a new message to the connection. + // + // Write may be called concurrently, as calls or reponses may occur + // concurrently in user code. + Write(context.Context, jsonrpc.Message) error + + // Close closes the connection. It is implicitly called whenever a Read or + // Write fails. + // + // Close may be called multiple times, potentially concurrently. + Close() error + + // TODO(#148): remove SessionID from this interface. SessionID() string } -// A StdioTransport is a [Transport] that communicates over stdin/stdout using -// newline-delimited JSON. -type StdioTransport struct { - ioTransport -} +// A ClientConnection is a [Connection] that is specific to the MCP client. +// +// If client connections implement this interface, they may receive information +// about changes to the client session. +// +// TODO: should this interface be exported? +type clientConnection interface { + Connection -// An ioTransport is a [Transport] that communicates using newline-delimited -// JSON over an io.ReadWriteCloser. -type ioTransport struct { - rwc io.ReadWriteCloser + // SessionUpdated is called whenever the client session state changes. + sessionUpdated(clientSessionState) } -func (t *ioTransport) Connect(context.Context) (Connection, error) { - return newIOConn(t.rwc), nil +// A serverConnection is a Connection that is specific to the MCP server. +// +// If server connections implement this interface, they receive information +// about changes to the server session. +// +// TODO: should this interface be exported? +type serverConnection interface { + Connection + sessionUpdated(ServerSessionState) } -// NewStdioTransport constructs a transport that communicates over -// stdin/stdout. -func NewStdioTransport() *StdioTransport { - return &StdioTransport{ioTransport{rwc{os.Stdin, os.Stdout}}} +// A StdioTransport is a [Transport] that communicates over stdin/stdout using +// newline-delimited JSON. +type StdioTransport struct{} + +// Connect implements the [Transport] interface. +func (*StdioTransport) Connect(context.Context) (Connection, error) { + return newIOConn(rwc{os.Stdin, os.Stdout}), nil } // An InMemoryTransport is a [Transport] that communicates over an in-memory // network connection, using newline-delimited JSON. type InMemoryTransport struct { - ioTransport + rwc io.ReadWriteCloser +} + +// Connect implements the [Transport] interface. +func (t *InMemoryTransport) Connect(context.Context) (Connection, error) { + return newIOConn(t.rwc), nil } -// NewInMemoryTransports returns two InMemoryTransports that connect to each +// NewInMemoryTransports returns two [InMemoryTransports] that connect to each // other. func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) { c1, c2 := net.Pipe() - return &InMemoryTransport{ioTransport{c1}}, &InMemoryTransport{ioTransport{c2}} + return &InMemoryTransport{c1}, &InMemoryTransport{c2} } -type binder[T handler] interface { - bind(*jsonrpc2.Connection) T +type binder[T handler, State any] interface { + // TODO(rfindley): the bind API has gotten too complicated. Simplify. + bind(Connection, *jsonrpc2.Connection, State, func()) T disconnect(T) } type handler interface { - handle(ctx context.Context, req *JSONRPCRequest) (any, error) - setConn(Connection) + handle(ctx context.Context, req *jsonrpc.Request) (any, error) } -func connect[H handler](ctx context.Context, t Transport, b binder[H]) (H, error) { +func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, State], s State, onClose func()) (H, error) { var zero H - conn, err := t.Connect(ctx) + mcpConn, err := t.Connect(ctx) if err != nil { return zero, err } // If logging is configured, write message logs. - reader, writer := jsonrpc2.Reader(conn), jsonrpc2.Writer(conn) + reader, writer := jsonrpc2.Reader(mcpConn), jsonrpc2.Writer(mcpConn) var ( h H preempter canceller ) bind := func(conn *jsonrpc2.Connection) jsonrpc2.Handler { - h = b.bind(conn) + h = b.bind(mcpConn, conn, s, onClose) preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ Reader: reader, Writer: writer, - Closer: conn, + Closer: mcpConn, Bind: bind, Preempter: &preempter, OnDone: func() { b.disconnect(h) }, + OnInternalError: func(err error) { log.Printf("jsonrpc2 error: %v", err) }, }) assert(preempter.conn != nil, "unbound preempter") - h.setConn(conn) return h, nil } @@ -136,9 +159,9 @@ type canceller struct { conn *jsonrpc2.Connection } -// Preempt implements jsonrpc2.Preempter. -func (c *canceller) Preempt(ctx context.Context, req *JSONRPCRequest) (result any, err error) { - if req.Method == "notifications/cancelled" { +// Preempt implements [jsonrpc2.Preempter]. +func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { + if req.Method == notificationCancelled { var params CancelledParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err @@ -161,7 +184,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params err := call.Await(ctx, result) switch { case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): - return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed) + return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: // Notify the peer of cancellation. err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{ @@ -178,59 +201,65 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params // A LoggingTransport is a [Transport] that delegates to another transport, // writing RPC logs to an io.Writer. type LoggingTransport struct { - delegate Transport - w io.Writer -} - -// NewLoggingTransport creates a new LoggingTransport that delegates to the -// provided transport, writing RPC logs to the provided io.Writer. -func NewLoggingTransport(delegate Transport, w io.Writer) *LoggingTransport { - return &LoggingTransport{delegate, w} + Transport Transport + Writer io.Writer } // Connect connects the underlying transport, returning a [Connection] that writes // logs to the configured destination. func (t *LoggingTransport) Connect(ctx context.Context) (Connection, error) { - delegate, err := t.delegate.Connect(ctx) + delegate, err := t.Transport.Connect(ctx) if err != nil { return nil, err } - return &loggingConn{delegate, t.w}, nil + return &loggingConn{delegate: delegate, w: t.Writer}, nil } type loggingConn struct { delegate Connection - w io.Writer + + mu sync.Mutex + w io.Writer } func (c *loggingConn) SessionID() string { return c.delegate.SessionID() } -// loggingReader is a stream middleware that logs incoming messages. -func (s *loggingConn) Read(ctx context.Context) (JSONRPCMessage, error) { +// Read is a stream middleware that logs incoming messages. +func (s *loggingConn) Read(ctx context.Context) (jsonrpc.Message, error) { msg, err := s.delegate.Read(ctx) + if err != nil { + s.mu.Lock() fmt.Fprintf(s.w, "read error: %v", err) + s.mu.Unlock() } else { data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() if err != nil { fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } fmt.Fprintf(s.w, "read: %s\n", string(data)) + s.mu.Unlock() } + return msg, err } -// loggingWriter is a stream middleware that logs outgoing messages. -func (s *loggingConn) Write(ctx context.Context, msg JSONRPCMessage) error { +// Write is a stream middleware that logs outgoing messages. +func (s *loggingConn) Write(ctx context.Context, msg jsonrpc.Message) error { err := s.delegate.Write(ctx, msg) if err != nil { + s.mu.Lock() fmt.Fprintf(s.w, "write error: %v", err) + s.mu.Unlock() } else { data, err := jsonrpc2.EncodeMessage(msg) + s.mu.Lock() if err != nil { fmt.Fprintf(s.w, "LoggingTransport: failed to marshal: %v", err) } fmt.Fprintf(s.w, "write: %s\n", string(data)) + s.mu.Unlock() } return err } @@ -259,39 +288,103 @@ func (r rwc) Close() error { } // An ioConn is a transport that delimits messages with newlines across -// a bidirectional stream, and supports JSONRPC2 message batching. +// a bidirectional stream, and supports jsonrpc.2 message batching. // // See https://github.com/ndjson/ndjson-spec for discussion of newline // delimited JSON. // // See [msgBatch] for more discussion of message batching. type ioConn struct { - rwc io.ReadWriteCloser // the underlying stream - in *json.Decoder // a decoder bound to rwc + protocolVersion string // negotiated version, set during session initialization. + + writeMu sync.Mutex // guards Write, which must be concurrency safe. + rwc io.ReadWriteCloser // the underlying stream + + // incoming receives messages from the read loop started in [newIOConn]. + incoming <-chan msgOrErr // If outgoiBatch has a positive capacity, it will be used to batch requests // and notifications before sending. - outgoingBatch []JSONRPCMessage + outgoingBatch []jsonrpc.Message // Unread messages in the last batch. Since reads are serialized, there is no // need to guard here. - queue []JSONRPCMessage + queue []jsonrpc.Message // batches correlate incoming requests to the batch in which they arrived. // Since writes may be concurrent to reads, we need to guard this with a mutex. batchMu sync.Mutex batches map[jsonrpc2.ID]*msgBatch // lazily allocated + + closeOnce sync.Once + closed chan struct{} + closeErr error +} + +type msgOrErr struct { + msg json.RawMessage + err error } func newIOConn(rwc io.ReadWriteCloser) *ioConn { + var ( + incoming = make(chan msgOrErr) + closed = make(chan struct{}) + ) + // Start a goroutine for reads, so that we can select on the incoming channel + // in [ioConn.Read] and unblock the read as soon as Close is called (see #224). + // + // This leaks a goroutine if rwc.Read does not unblock after it is closed, + // but that is unavoidable since AFAIK there is no (easy and portable) way to + // guarantee that reads of stdin are unblocked when closed. + go func() { + dec := json.NewDecoder(rwc) + for { + var raw json.RawMessage + err := dec.Decode(&raw) + // If decoding was successful, check for trailing data at the end of the stream. + if err == nil { + // Read the next byte to check if there is trailing data. + var tr [1]byte + if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { + // If read byte is not a newline, it is an error. + if tr[0] != '\n' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } + select { + case incoming <- msgOrErr{msg: raw, err: err}: + case <-closed: + return + } + if err != nil { + return + } + } + }() return &ioConn{ - rwc: rwc, - in: json.NewDecoder(rwc), + rwc: rwc, + incoming: incoming, + closed: closed, } } func (c *ioConn) SessionID() string { return "" } +func (c *ioConn) sessionUpdated(state ServerSessionState) { + protocolVersion := "" + if state.InitializeParams != nil { + protocolVersion = state.InitializeParams.ProtocolVersion + } + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + c.protocolVersion = negotiatedVersion(protocolVersion) +} + // addBatch records a msgBatch for an incoming batch payload. // It returns an error if batch is malformed, containing previously seen IDs. // @@ -319,7 +412,7 @@ func (t *ioConn) addBatch(batch *msgBatch) error { // The second result reports whether resp was part of a batch. If this is true, // the first result is nil if the batch is still incomplete, or the full set of // batch responses if resp completed the batch. -func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { +func (t *ioConn) updateBatch(resp *jsonrpc.Response) ([]*jsonrpc.Response, bool) { t.batchMu.Lock() defer t.batchMu.Unlock() @@ -339,9 +432,9 @@ func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { return nil, false } -// A msgBatch records information about an incoming batch of JSONRPC2 calls. +// A msgBatch records information about an incoming batch of jsonrpc.2 calls. // -// The JSONRPC2 spec (https://www.jsonrpc.org/specification#batch) says: +// The jsonrpc.2 spec (https://www.jsonrpc.org/specification#batch) says: // // "The Server should respond with an Array containing the corresponding // Response objects, after all of the batch Request objects have been @@ -354,14 +447,12 @@ func (t *ioConn) updateBatch(resp *JSONRPCResponse) ([]*JSONRPCResponse, bool) { // When there are no unresolved calls, the response payload is sent. type msgBatch struct { unresolved map[jsonrpc2.ID]int - responses []*JSONRPCResponse -} - -func (t *ioConn) Read(ctx context.Context) (JSONRPCMessage, error) { - return t.read(ctx, t.in) + responses []*jsonrpc.Response } -func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, error) { +func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { + // As a matter of principle, enforce that reads on a closed context return an + // error. select { case <-ctx.Done(): return nil, ctx.Err() @@ -374,19 +465,34 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, er } var raw json.RawMessage - if err := in.Decode(&raw); err != nil { - return nil, err + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case v := <-t.incoming: + if v.err != nil { + return nil, v.err + } + raw = v.msg + + case <-t.closed: + return nil, io.EOF } + msgs, batch, err := readBatch(raw) if err != nil { return nil, err } + if batch && t.protocolVersion >= protocolVersion20250618 { + return nil, fmt.Errorf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, t.protocolVersion) + } + t.queue = msgs[1:] if batch { var respBatch *msgBatch // track incoming requests in the batch for _, msg := range msgs { - if req, ok := msg.(*JSONRPCRequest); ok { + if req, ok := msg.(*jsonrpc.Request); ok { if respBatch == nil { respBatch = &msgBatch{ unresolved: make(map[jsonrpc2.ID]int), @@ -411,7 +517,7 @@ func (t *ioConn) read(ctx context.Context, in *json.Decoder) (JSONRPCMessage, er // readBatch reads batch data, which may be either a single JSON-RPC message, // or an array of JSON-RPC messages. -func readBatch(data []byte) (msgs []JSONRPCMessage, isBatch bool, _ error) { +func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { // Try to read an array of messages first. var rawBatch []json.RawMessage if err := json.Unmarshal(data, &rawBatch); err == nil { @@ -429,21 +535,25 @@ func readBatch(data []byte) (msgs []JSONRPCMessage, isBatch bool, _ error) { } // Try again with a single message. msg, err := jsonrpc2.DecodeMessage(data) - return []JSONRPCMessage{msg}, false, err + return []jsonrpc.Message{msg}, false, err } -func (t *ioConn) Write(ctx context.Context, msg JSONRPCMessage) error { +func (t *ioConn) Write(ctx context.Context, msg jsonrpc.Message) error { + // As in [ioConn.Read], enforce that Writes on a closed context are an error. select { case <-ctx.Done(): return ctx.Err() default: } + t.writeMu.Lock() + defer t.writeMu.Unlock() + // Batching support: if msg is a Response, it may have completed a batch, so // check that first. Otherwise, it is a request or notification, and we may // want to collect it into a batch before sending, if we're configured to use // outgoing batches. - if resp, ok := msg.(*JSONRPCResponse); ok { + if resp, ok := msg.(*jsonrpc.Response); ok { if batch, ok := t.updateBatch(resp); ok { if len(batch) > 0 { data, err := marshalMessages(batch) @@ -480,10 +590,14 @@ func (t *ioConn) Write(ctx context.Context, msg JSONRPCMessage) error { } func (t *ioConn) Close() error { - return t.rwc.Close() + t.closeOnce.Do(func() { + t.closeErr = t.rwc.Close() + close(t.closed) + }) + return t.closeErr } -func marshalMessages[T JSONRPCMessage](msgs []T) ([]byte, error) { +func marshalMessages[T jsonrpc.Message](msgs []T) ([]byte, error) { var rawMsgs []json.RawMessage for _, msg := range msgs { raw, err := jsonrpc2.EncodeMessage(msg) diff --git a/mcp/transport_example_test.go b/mcp/transport_example_test.go new file mode 100644 index 00000000..ab54a422 --- /dev/null +++ b/mcp/transport_example_test.go @@ -0,0 +1,53 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Uses strings.SplitSeq. +//go:build go1.24 + +package mcp_test + +import ( + "bytes" + "context" + "fmt" + "log" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// !+loggingtransport + +func ExampleLoggingTransport() { + ctx := context.Background() + t1, t2 := mcp.NewInMemoryTransports() + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + serverSession, err := server.Connect(ctx, t1, nil) + if err != nil { + log.Fatal(err) + } + defer serverSession.Wait() + + client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) + var b bytes.Buffer + logTransport := &mcp.LoggingTransport{Transport: t2, Writer: &b} + clientSession, err := client.Connect(ctx, logTransport, nil) + if err != nil { + log.Fatal(err) + } + defer clientSession.Close() + + // Sort for stability: reads are concurrent to writes. + for _, line := range slices.Sorted(strings.SplitSeq(b.String(), "\n")) { + fmt.Println(line) + } + + // Output: + // read: {"jsonrpc":"2.0","id":1,"result":{"capabilities":{"logging":{}},"protocolVersion":"2025-06-18","serverInfo":{"name":"server","version":"v0.0.1"}}} + // write: {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{"roots":{"listChanged":true}},"clientInfo":{"name":"client","version":"v0.0.1"},"protocolVersion":"2025-06-18"}} + // write: {"jsonrpc":"2.0","method":"notifications/initialized","params":{}} +} + +// !-loggingtransport diff --git a/mcp/transport_test.go b/mcp/transport_test.go index db18a352..10804a87 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -7,9 +7,11 @@ package mcp import ( "context" "io" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) func TestBatchFraming(t *testing.T) { @@ -22,10 +24,11 @@ func TestBatchFraming(t *testing.T) { r, w := io.Pipe() tport := newIOConn(rwc{r, w}) - tport.outgoingBatch = make([]JSONRPCMessage, 0, 2) + tport.outgoingBatch = make([]jsonrpc.Message, 0, 2) + defer tport.Close() // Read the two messages into a channel, for easy testing later. - read := make(chan JSONRPCMessage) + read := make(chan jsonrpc.Message) go func() { for range 2 { msg, _ := tport.Read(ctx) @@ -34,7 +37,7 @@ func TestBatchFraming(t *testing.T) { }() // The first write should not yet be observed by the reader. - tport.Write(ctx, &JSONRPCRequest{ID: jsonrpc2.Int64ID(1), Method: "test"}) + tport.Write(ctx, &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "test"}) select { case got := <-read: t.Fatalf("after one write, got message %v", got) @@ -42,11 +45,76 @@ func TestBatchFraming(t *testing.T) { } // ...but the second write causes both messages to be observed. - tport.Write(ctx, &JSONRPCRequest{ID: jsonrpc2.Int64ID(2), Method: "test"}) + tport.Write(ctx, &jsonrpc.Request{ID: jsonrpc2.Int64ID(2), Method: "test"}) for _, want := range []int64{1, 2} { got := <-read - if got := got.(*JSONRPCRequest).ID.Raw(); got != want { + if got := got.(*jsonrpc.Request).ID.Raw(); got != want { t.Errorf("got message #%d, want #%d", got, want) } } } + +func TestIOConnRead(t *testing.T) { + tests := []struct { + name string + input string + want string + protocolVersion string + }{ + { + name: "valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`, + want: "", + }, + { + name: "newline at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}} + `, + want: "", + }, + { + name: "bad data at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`, + want: "invalid trailing data at the end of stream", + }, + { + name: "batching unknown protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "", + protocolVersion: "", + }, + { + name: "batching old protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "", + protocolVersion: protocolVersion20241105, + }, + { + name: "batching new protocol", + input: `[{"jsonrpc":"2.0","id":1,"method":"test1"},{"jsonrpc":"2.0","id":2,"method":"test2"}]`, + want: "JSON-RPC batching is not supported in 2025-06-18 and later (request version: 2025-06-18)", + protocolVersion: protocolVersion20250618, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := newIOConn(rwc{ + rc: io.NopCloser(strings.NewReader(tt.input)), + }) + if tt.protocolVersion != "" { + tr.sessionUpdated(ServerSessionState{ + InitializeParams: &InitializeParams{ + ProtocolVersion: tt.protocolVersion, + }, + }) + } + _, err := tr.Read(context.Background()) + if err == nil && tt.want != "" { + t.Errorf("ioConn.Read() got nil error but wanted %v", tt.want) + } + if err != nil && err.Error() != tt.want { + t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) + } + }) + } +} diff --git a/mcp/util.go b/mcp/util.go index 82d87940..102c0885 100644 --- a/mcp/util.go +++ b/mcp/util.go @@ -6,12 +6,6 @@ package mcp import ( "crypto/rand" - "encoding/json" - "fmt" - "reflect" - "sync" - - "github.com/modelcontextprotocol/go-sdk/internal/util" ) func assert(cond bool, msg string) { @@ -33,114 +27,3 @@ func randText() string { } return string(src) } - -// marshalStructWithMap marshals its first argument to JSON, treating the field named -// mapField as an embedded map. The first argument must be a pointer to -// a struct. The underlying type of mapField must be a map[string]any, and it must have -// an "omitempty" json tag. -// -// For example, given this struct: -// -// type S struct { -// A int -// Extra map[string] any `json:,omitempty` -// } -// -// and this value: -// -// s := S{A: 1, Extra: map[string]any{"B": 2}} -// -// the call marshalJSONWithMap(s, "Extra") would return -// -// {"A": 1, "B": 2} -// -// It is an error if the map contains the same key as another struct field's -// JSON name. -// -// marshalStructWithMap calls json.Marshal on a value of type T, so T must not -// have a MarshalJSON method that calls this function, on pain of infinite regress. -// -// TODO: avoid this restriction on T by forcing it to marshal in a default way. -// See https://go.dev/play/p/EgXKJHxEx_R. -func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) { - // Marshal the struct and the map separately, and concatenate the bytes. - // This strategy is dramatically less complicated than - // constructing a synthetic struct or map with the combined keys. - if s == nil { - return []byte("null"), nil - } - s2 := *s - vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField) - mapVal := vMapField.Interface().(map[string]any) - - // Check for duplicates. - names := jsonNames(reflect.TypeFor[T]()) - for key := range mapVal { - if names[key] { - return nil, fmt.Errorf("map key %q duplicates struct field", key) - } - } - - // Clear the map field, relying on the omitempty tag to omit it. - vMapField.Set(reflect.Zero(vMapField.Type())) - structBytes, err := json.Marshal(s2) - if err != nil { - return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err) - } - if len(mapVal) == 0 { - return structBytes, nil - } - mapBytes, err := json.Marshal(mapVal) - if err != nil { - return nil, err - } - if len(structBytes) == 2 { // must be "{}" - return mapBytes, nil - } - // "{X}" + "{Y}" => "{X,Y}" - res := append(structBytes[:len(structBytes)-1], ',') - res = append(res, mapBytes[1:]...) - return res, nil -} - -// unmarshalStructWithMap is the inverse of marshalStructWithMap. -// T has the same restrictions as in that function. -func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error { - // Unmarshal into the struct, ignoring unknown fields. - if err := json.Unmarshal(data, v); err != nil { - return err - } - // Unmarshal into the map. - m := map[string]any{} - if err := json.Unmarshal(data, &m); err != nil { - return err - } - // Delete from the map the fields of the struct. - for n := range jsonNames(reflect.TypeFor[T]()) { - delete(m, n) - } - if len(m) != 0 { - reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m)) - } - return nil -} - -var jsonNamesMap sync.Map // from reflect.Type to map[string]bool - -// jsonNames returns the set of JSON object keys that t will marshal into. -// t must be a struct type. -func jsonNames(t reflect.Type) map[string]bool { - // Lock not necessary: at worst we'll duplicate work. - if val, ok := jsonNamesMap.Load(t); ok { - return val.(map[string]bool) - } - m := map[string]bool{} - for i := range t.NumField() { - info := util.FieldJSONInfo(t.Field(i)) - if !info.Omit { - m[info.Name] = true - } - } - jsonNamesMap.Store(t, m) - return m -} diff --git a/mcp/util_test.go b/mcp/util_test.go deleted file mode 100644 index f2cb0f5c..00000000 --- a/mcp/util_test.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -package mcp - -import ( - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func TestMarshalStructWithMap(t *testing.T) { - type S struct { - A int - B string `json:"b,omitempty"` - u bool - M map[string]any `json:",omitempty"` - } - t.Run("basic", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"!@#": true}} - got, err := marshalStructWithMap(&s, "M") - if err != nil { - t.Fatal(err) - } - want := `{"A":1,"b":"two","!@#":true}` - if g := string(got); g != want { - t.Errorf("\ngot %s\nwant %s", g, want) - } - - var un S - if err := unmarshalStructWithMap(got, &un, "M"); err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(s, un, cmpopts.IgnoreUnexported(S{})); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } - }) - t.Run("duplicate", func(t *testing.T) { - s := S{A: 1, B: "two", M: map[string]any{"b": "dup"}} - _, err := marshalStructWithMap(&s, "M") - if err == nil || !strings.Contains(err.Error(), "duplicate") { - t.Errorf("got %v, want error with 'duplicate'", err) - } - }) -} diff --git a/oauthex/oauthex.go b/oauthex/oauthex.go new file mode 100644 index 00000000..eb3a3e78 --- /dev/null +++ b/oauthex/oauthex.go @@ -0,0 +1,10 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex + +import "github.com/modelcontextprotocol/go-sdk/internal/oauthex" + +type ProtectedResourceMetadata = oauthex.ProtectedResourceMetadata