|
8 | 8 | import socket
|
9 | 9 | import time
|
10 | 10 | from collections.abc import Generator
|
| 11 | +from typing import Any |
11 | 12 |
|
12 | 13 | import anyio
|
13 | 14 | import httpx
|
|
33 | 34 | StreamId,
|
34 | 35 | )
|
35 | 36 | from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
| 37 | +from mcp.shared.context import RequestContext |
36 | 38 | from mcp.shared.exceptions import McpError
|
37 | 39 | from mcp.shared.message import (
|
38 | 40 | ClientMessageMetadata,
|
@@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]:
|
139 | 141 | description="A long-running tool that sends periodic notifications",
|
140 | 142 | inputSchema={"type": "object", "properties": {}},
|
141 | 143 | ),
|
| 144 | + Tool( |
| 145 | + name="test_sampling_tool", |
| 146 | + description="A tool that triggers server-side sampling", |
| 147 | + inputSchema={"type": "object", "properties": {}}, |
| 148 | + ), |
142 | 149 | ]
|
143 | 150 |
|
144 | 151 | @self.call_tool()
|
@@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
|
174 | 181 |
|
175 | 182 | return [TextContent(type="text", text="Completed!")]
|
176 | 183 |
|
| 184 | + elif name == "test_sampling_tool": |
| 185 | + # Test sampling by requesting the client to sample a message |
| 186 | + sampling_result = await ctx.session.create_message( |
| 187 | + messages=[ |
| 188 | + types.SamplingMessage( |
| 189 | + role="user", |
| 190 | + content=types.TextContent( |
| 191 | + type="text", text="Server needs client sampling" |
| 192 | + ), |
| 193 | + ) |
| 194 | + ], |
| 195 | + max_tokens=100, |
| 196 | + related_request_id=ctx.request_id, |
| 197 | + ) |
| 198 | + |
| 199 | + # Return the sampling result in the tool response |
| 200 | + response = ( |
| 201 | + sampling_result.content.text |
| 202 | + if sampling_result.content.type == "text" |
| 203 | + else None |
| 204 | + ) |
| 205 | + return [ |
| 206 | + TextContent( |
| 207 | + type="text", |
| 208 | + text=f"Response from sampling: {response}", |
| 209 | + ) |
| 210 | + ] |
| 211 | + |
177 | 212 | return [TextContent(type="text", text=f"Called {name}")]
|
178 | 213 |
|
179 | 214 |
|
@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
|
754 | 789 | """Test client tool invocation."""
|
755 | 790 | # First list tools
|
756 | 791 | tools = await initialized_client_session.list_tools()
|
757 |
| - assert len(tools.tools) == 3 |
| 792 | + assert len(tools.tools) == 4 |
758 | 793 | assert tools.tools[0].name == "test_tool"
|
759 | 794 |
|
760 | 795 | # Call the tool
|
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(
|
795 | 830 |
|
796 | 831 | # Make multiple requests to verify session persistence
|
797 | 832 | tools = await session.list_tools()
|
798 |
| - assert len(tools.tools) == 3 |
| 833 | + assert len(tools.tools) == 4 |
799 | 834 |
|
800 | 835 | # Read a resource
|
801 | 836 | resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
|
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(
|
826 | 861 |
|
827 | 862 | # Check tool listing
|
828 | 863 | tools = await session.list_tools()
|
829 |
| - assert len(tools.tools) == 3 |
| 864 | + assert len(tools.tools) == 4 |
830 | 865 |
|
831 | 866 | # Call a tool and verify JSON response handling
|
832 | 867 | result = await session.call_tool("test_tool", {})
|
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(
|
905 | 940 |
|
906 | 941 | # Make a request to confirm session is working
|
907 | 942 | tools = await session.list_tools()
|
908 |
| - assert len(tools.tools) == 3 |
| 943 | + assert len(tools.tools) == 4 |
909 | 944 |
|
910 | 945 | headers = {}
|
911 | 946 | if captured_session_id:
|
@@ -1054,3 +1089,71 @@ async def run_tool():
|
1054 | 1089 | assert not any(
|
1055 | 1090 | n in captured_notifications_pre for n in captured_notifications
|
1056 | 1091 | )
|
| 1092 | + |
| 1093 | + |
| 1094 | +@pytest.mark.anyio |
| 1095 | +async def test_streamablehttp_server_sampling(basic_server, basic_server_url): |
| 1096 | + """Test server-initiated sampling request through streamable HTTP transport.""" |
| 1097 | + print("Testing server sampling...") |
| 1098 | + # Variable to track if sampling callback was invoked |
| 1099 | + sampling_callback_invoked = False |
| 1100 | + captured_message_params = None |
| 1101 | + |
| 1102 | + # Define sampling callback that returns a mock response |
| 1103 | + async def sampling_callback( |
| 1104 | + context: RequestContext[ClientSession, Any], |
| 1105 | + params: types.CreateMessageRequestParams, |
| 1106 | + ) -> types.CreateMessageResult: |
| 1107 | + nonlocal sampling_callback_invoked, captured_message_params |
| 1108 | + sampling_callback_invoked = True |
| 1109 | + captured_message_params = params |
| 1110 | + message_received = ( |
| 1111 | + params.messages[0].content.text |
| 1112 | + if params.messages[0].content.type == "text" |
| 1113 | + else None |
| 1114 | + ) |
| 1115 | + |
| 1116 | + return types.CreateMessageResult( |
| 1117 | + role="assistant", |
| 1118 | + content=types.TextContent( |
| 1119 | + type="text", |
| 1120 | + text=f"Received message from server: {message_received}", |
| 1121 | + ), |
| 1122 | + model="test-model", |
| 1123 | + stopReason="endTurn", |
| 1124 | + ) |
| 1125 | + |
| 1126 | + # Create client with sampling callback |
| 1127 | + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( |
| 1128 | + read_stream, |
| 1129 | + write_stream, |
| 1130 | + _, |
| 1131 | + ): |
| 1132 | + async with ClientSession( |
| 1133 | + read_stream, |
| 1134 | + write_stream, |
| 1135 | + sampling_callback=sampling_callback, |
| 1136 | + ) as session: |
| 1137 | + # Initialize the session |
| 1138 | + result = await session.initialize() |
| 1139 | + assert isinstance(result, InitializeResult) |
| 1140 | + |
| 1141 | + # Call the tool that triggers server-side sampling |
| 1142 | + tool_result = await session.call_tool("test_sampling_tool", {}) |
| 1143 | + |
| 1144 | + # Verify the tool result contains the expected content |
| 1145 | + assert len(tool_result.content) == 1 |
| 1146 | + assert tool_result.content[0].type == "text" |
| 1147 | + assert ( |
| 1148 | + "Response from sampling: Received message from server" |
| 1149 | + in tool_result.content[0].text |
| 1150 | + ) |
| 1151 | + |
| 1152 | + # Verify sampling callback was invoked |
| 1153 | + assert sampling_callback_invoked |
| 1154 | + assert captured_message_params is not None |
| 1155 | + assert len(captured_message_params.messages) == 1 |
| 1156 | + assert ( |
| 1157 | + captured_message_params.messages[0].content.text |
| 1158 | + == "Server needs client sampling" |
| 1159 | + ) |
0 commit comments