diff --git a/Makefile b/Makefile index b9e3fa5..8b067af 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ mypy: .PHONY: tests tests: - uv run pytest + uv run --extra test pytest .PHONY: coverage coverage: diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index cdf2d76..b16b562 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -18,6 +18,14 @@ from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client +from mcp.client.websocket import websocket_client + + +TRANSPORTS = { + "sse": sse_client, + "streamable-http": streamablehttp_client, + "ws": websocket_client, +} class ToolAdapter(ABC): @@ -100,14 +108,13 @@ async def mcptools( # Create a deep copy to avoid modifying the original dict client_params = copy.deepcopy(serverparams) transport = client_params.pop("transport", "sse") - if transport == "sse": - client = sse_client(**client_params) - elif transport == "streamable-http": - client = streamablehttp_client(**client_params) + if transport in TRANSPORTS: + client = TRANSPORTS[transport](**client_params) else: raise ValueError( - f"Invalid transport, expected sse or streamable-http found `{transport}`" + f"Invalid transport, expected {list(TRANSPORTS.keys())} found `{transport}`" ) + else: raise ValueError( f"Invalid serverparams, expected StdioServerParameters or dict found `{type(serverparams)}`" diff --git a/tests/test_core.py b/tests/test_core.py index f66170e..cce7b1a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -153,6 +153,131 @@ async def echo_streamable_http_server(echo_server_streamable_http_script): process.wait() +@pytest.fixture +def echo_server_websocket_script(): + return dedent( + ''' + import asyncio + import json + import websockets + + async def handle_mcp_client(websocket): + """Handle MCP protocol over websocket""" + try: + async for message in websocket: + data = json.loads(message) + + if data.get("method") == "initialize": + # Send initialize response + response = { + "jsonrpc": "2.0", + "id": data.get("id"), + "result": { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "serverInfo": { + "name": "Echo Server", + "version": "1.0.0" + } + } + } + await websocket.send(json.dumps(response)) + + elif data.get("method") == "tools/list": + # Send tools list response + response = { + "jsonrpc": "2.0", + "id": data.get("id"), + "result": { + "tools": [{ + "name": "echo_tool", + "description": "Echo the input text", + "inputSchema": { + "type": "object", + "properties": { + "text": {"type": "string"} + }, + "required": ["text"] + } + }] + } + } + await websocket.send(json.dumps(response)) + + elif data.get("method") == "tools/call": + # Handle tool call + tool_name = data.get("params", {}).get("name") + arguments = data.get("params", {}).get("arguments", {}) + + if tool_name == "echo_tool": + text = arguments.get("text", "") + response = { + "jsonrpc": "2.0", + "id": data.get("id"), + "result": { + "content": [{ + "type": "text", + "text": f"Echo: {text}" + }] + } + } + await websocket.send(json.dumps(response)) + + except websockets.exceptions.ConnectionClosed: + pass + + async def main(): + server = await websockets.serve(handle_mcp_client, "127.0.0.1", 8001) + await server.wait_closed() + + if __name__ == "__main__": + asyncio.run(main()) + ''' + ) + + +@pytest.fixture +def echo_websocket_server(echo_server_websocket_script): + import subprocess + + # Start the WebSocket server process + process = subprocess.Popen( + ["python", "-c", echo_server_websocket_script], + ) + + # Give the server a moment to start up + time.sleep(1) + + try: + yield {"url": "ws://127.0.0.1:8001/ws", "transport": "ws"} + finally: + # Clean up the process when test is done + process.kill() + process.wait() + + +@pytest.fixture +async def echo_websocket_server_async(echo_server_websocket_script): + import subprocess + + # Start the WebSocket server process + process = subprocess.Popen( + ["python", "-c", echo_server_websocket_script], + ) + + # Give the server a moment to start up + time.sleep(1) + + try: + yield {"url": "ws://127.0.0.1:8001/ws", "transport": "ws"} + finally: + # Clean up the process when test is done + process.kill() + process.wait() + + @pytest.fixture def slow_start_server_script(): return dedent( @@ -343,6 +468,51 @@ async def test_basic_async_streamable_http(echo_streamable_http_server): assert (await tools[0]({"text": "hello"})).content[0].text == "Echo: hello" +def test_basic_sync_websocket(echo_websocket_server): + ws_serverparams = echo_websocket_server + with MCPAdapt( + ws_serverparams, + DummyAdapter(), + ) as tools: + assert len(tools) == 1 + assert tools[0]({"text": "hello"}).content[0].text == "Echo: hello" + + +def test_basic_sync_multiple_websocket(echo_websocket_server): + ws_serverparams = echo_websocket_server + with MCPAdapt( + [ws_serverparams, ws_serverparams], + DummyAdapter(), + ) as tools: + assert len(tools) == 2 + assert tools[0]({"text": "hello"}).content[0].text == "Echo: hello" + assert tools[1]({"text": "world"}).content[0].text == "Echo: world" + + +async def test_basic_async_websocket(echo_websocket_server): + ws_serverparams = echo_websocket_server + async with MCPAdapt( + ws_serverparams, + DummyAdapter(), + ) as tools: + assert len(tools) == 1 + mcp_tool_call_result = await tools[0]({"text": "hello"}) + assert mcp_tool_call_result.content[0].text == "Echo: hello" + + +async def test_basic_async_multiple_websocket(echo_websocket_server): + ws_serverparams = echo_websocket_server + async with MCPAdapt( + [ws_serverparams, ws_serverparams], + DummyAdapter(), + ) as tools: + assert len(tools) == 2 + mcp_tool_call_result = await tools[0]({"text": "hello"}) + assert mcp_tool_call_result.content[0].text == "Echo: hello" + mcp_tool_call_result = await tools[1]({"text": "world"}) + assert mcp_tool_call_result.content[0].text == "Echo: world" + + def test_connect_timeout(slow_start_server_script): """Test that connect_timeout raises TimeoutError when server starts slowly""" with pytest.raises(