Skip to content
4 changes: 2 additions & 2 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = str | int
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really the only change that is necessary.

It will try to validate int first, then str.

AnyFunction: TypeAlias = Callable[..., Any]


Expand Down Expand Up @@ -353,7 +353,7 @@ class ProgressNotificationParams(NotificationParams):
"""Total number of items to process (or total progress required), if known."""
message: str | None = None
"""
Message related to progress. This should provide relevant human readable
Message related to progress. This should provide relevant human readable
progress information.
"""
model_config = ConfigDict(extra="allow")
Expand Down
16 changes: 16 additions & 0 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import httpx
import pytest
import uvicorn
from inline_snapshot import snapshot
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server import Server
Expand Down Expand Up @@ -503,3 +505,17 @@ async def test_request_context_isolation(context_server: None, server_url: str)
assert ctx["request_id"] == f"request-{i}"
assert ctx["headers"].get("x-request-id") == f"request-{i}"
assert ctx["headers"].get("x-custom-value") == f"value-{i}"


def test_sse_message_id_coercion():
"""Test that string message IDs that look like integers are parsed as integers.

See <https://github.com/modelcontextprotocol/python-sdk/pull/851> for more details.
"""
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
msg = types.JSONRPCMessage.model_validate_json(json_message)
assert msg == snapshot(
types.JSONRPCMessage(
root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)
)
)
Loading