Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mypy:

.PHONY: tests
tests:
uv run pytest
uv run --extra test pytest

.PHONY: coverage
coverage:
Expand Down
17 changes: 12 additions & 5 deletions src/mcpadapt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.websocket import websocket_client


TRANSPORTS = {
"sse": sse_client,
"streamable-http": streamablehttp_client,
"ws": websocket_client,
}


class ToolAdapter(ABC):
Expand Down Expand Up @@ -100,14 +108,13 @@ async def mcptools(
# Create a deep copy to avoid modifying the original dict
client_params = copy.deepcopy(serverparams)
transport = client_params.pop("transport", "sse")
if transport == "sse":
client = sse_client(**client_params)
elif transport == "streamable-http":
client = streamablehttp_client(**client_params)
if transport in TRANSPORTS:
client = TRANSPORTS[transport](**client_params)
else:
raise ValueError(
f"Invalid transport, expected sse or streamable-http found `{transport}`"
f"Invalid transport, expected {list(TRANSPORTS.keys())} found `{transport}`"
)

else:
raise ValueError(
f"Invalid serverparams, expected StdioServerParameters or dict found `{type(serverparams)}`"
Expand Down
170 changes: 170 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,131 @@ async def echo_streamable_http_server(echo_server_streamable_http_script):
process.wait()


@pytest.fixture
def echo_server_websocket_script():
return dedent(
'''
import asyncio
import json
import websockets

async def handle_mcp_client(websocket):
"""Handle MCP protocol over websocket"""
try:
async for message in websocket:
data = json.loads(message)

if data.get("method") == "initialize":
# Send initialize response
response = {
"jsonrpc": "2.0",
"id": data.get("id"),
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"serverInfo": {
"name": "Echo Server",
"version": "1.0.0"
}
}
}
await websocket.send(json.dumps(response))

elif data.get("method") == "tools/list":
# Send tools list response
response = {
"jsonrpc": "2.0",
"id": data.get("id"),
"result": {
"tools": [{
"name": "echo_tool",
"description": "Echo the input text",
"inputSchema": {
"type": "object",
"properties": {
"text": {"type": "string"}
},
"required": ["text"]
}
}]
}
}
await websocket.send(json.dumps(response))

elif data.get("method") == "tools/call":
# Handle tool call
tool_name = data.get("params", {}).get("name")
arguments = data.get("params", {}).get("arguments", {})

if tool_name == "echo_tool":
text = arguments.get("text", "")
response = {
"jsonrpc": "2.0",
"id": data.get("id"),
"result": {
"content": [{
"type": "text",
"text": f"Echo: {text}"
}]
}
}
await websocket.send(json.dumps(response))

except websockets.exceptions.ConnectionClosed:
pass

async def main():
server = await websockets.serve(handle_mcp_client, "127.0.0.1", 8001)
await server.wait_closed()

if __name__ == "__main__":
asyncio.run(main())
'''
)


@pytest.fixture
def echo_websocket_server(echo_server_websocket_script):
import subprocess

# Start the WebSocket server process
process = subprocess.Popen(
["python", "-c", echo_server_websocket_script],
)

# Give the server a moment to start up
time.sleep(1)

try:
yield {"url": "ws://127.0.0.1:8001/ws", "transport": "ws"}
finally:
# Clean up the process when test is done
process.kill()
process.wait()


@pytest.fixture
async def echo_websocket_server_async(echo_server_websocket_script):
import subprocess

# Start the WebSocket server process
process = subprocess.Popen(
["python", "-c", echo_server_websocket_script],
)

# Give the server a moment to start up
time.sleep(1)

try:
yield {"url": "ws://127.0.0.1:8001/ws", "transport": "ws"}
finally:
# Clean up the process when test is done
process.kill()
process.wait()


@pytest.fixture
def slow_start_server_script():
return dedent(
Expand Down Expand Up @@ -343,6 +468,51 @@ async def test_basic_async_streamable_http(echo_streamable_http_server):
assert (await tools[0]({"text": "hello"})).content[0].text == "Echo: hello"


def test_basic_sync_websocket(echo_websocket_server):
ws_serverparams = echo_websocket_server
with MCPAdapt(
ws_serverparams,
DummyAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0]({"text": "hello"}).content[0].text == "Echo: hello"


def test_basic_sync_multiple_websocket(echo_websocket_server):
ws_serverparams = echo_websocket_server
with MCPAdapt(
[ws_serverparams, ws_serverparams],
DummyAdapter(),
) as tools:
assert len(tools) == 2
assert tools[0]({"text": "hello"}).content[0].text == "Echo: hello"
assert tools[1]({"text": "world"}).content[0].text == "Echo: world"


async def test_basic_async_websocket(echo_websocket_server):
ws_serverparams = echo_websocket_server
async with MCPAdapt(
ws_serverparams,
DummyAdapter(),
) as tools:
assert len(tools) == 1
mcp_tool_call_result = await tools[0]({"text": "hello"})
assert mcp_tool_call_result.content[0].text == "Echo: hello"


async def test_basic_async_multiple_websocket(echo_websocket_server):
ws_serverparams = echo_websocket_server
async with MCPAdapt(
[ws_serverparams, ws_serverparams],
DummyAdapter(),
) as tools:
assert len(tools) == 2
mcp_tool_call_result = await tools[0]({"text": "hello"})
assert mcp_tool_call_result.content[0].text == "Echo: hello"
mcp_tool_call_result = await tools[1]({"text": "world"})
assert mcp_tool_call_result.content[0].text == "Echo: world"


def test_connect_timeout(slow_start_server_script):
"""Test that connect_timeout raises TimeoutError when server starts slowly"""
with pytest.raises(
Expand Down