Skip to content
Closed
16 changes: 3 additions & 13 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import re
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
Expand Down Expand Up @@ -62,21 +61,12 @@ async def sse_reader(
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
url_parsed = urlparse(url)

base_path = re.search(
r"https?://[^/]+/(.+?)(?:/mcp)?/sse$", url
)
base_path = (
base_path.group(1) if base_path else ""
)
endpoint_url = urljoin(
url_parsed.scheme + "://" + url_parsed.netloc, # noqa: E501
base_path + sse.data
)
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)

url_parsed = urlparse(url)

endpoint_parsed = urlparse(endpoint_url)
if (
Expand Down
7 changes: 6 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def handle_sse(request):
from urllib.parse import quote
from uuid import UUID, uuid4

import re
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
Expand Down Expand Up @@ -95,7 +96,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
request_path = scope["path"]
match = re.match(r"^/([^/]+(?:/mcp)?)/sse$", request_path)
mount_prefix = match.group(1) if match else ""
session_uri = f"/{quote(mount_prefix)}{quote(self._endpoint)}?session_id={session_id.hex}"

self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

Expand Down
Loading