From 0ea92e6b3c6fd557e4006c7d12bd3cab7466d475 Mon Sep 17 00:00:00 2001 From: Andres March <> Date: Tue, 24 Jun 2025 11:22:53 -0400 Subject: [PATCH] Make "resource" optional on earlier protocols --- src/mcp/client/auth.py | 38 ++++++++++++++++-- tests/client/test_auth.py | 82 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 359e0585f..769e9b4c8 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -95,6 +95,7 @@ class OAuthContext: protected_resource_metadata: ProtectedResourceMetadata | None = None oauth_metadata: OAuthMetadata | None = None auth_server_url: str | None = None + protocol_version: str | None = None # Client registration client_info: OAuthClientInformationFull | None = None @@ -154,6 +155,25 @@ def get_resource_url(self) -> str: return resource + def should_include_resource_param(self, protocol_version: str | None = None) -> bool: + """Determine if the resource parameter should be included in OAuth requests. + + Returns True if: + - Protected resource metadata is available, OR + - MCP-Protocol-Version header is 2025-06-18 or later + """ + # If we have protected resource metadata, include the resource param + if self.protected_resource_metadata is not None: + return True + + # If no protocol version provided, don't include resource param + if not protocol_version: + return False + + # Check if protocol version is 2025-06-18 or later + # Version format is YYYY-MM-DD, so string comparison works + return protocol_version >= "2025-06-18" + class OAuthClientProvider(httpx.Auth): """ @@ -320,9 +340,12 @@ async def _perform_authorization(self) -> tuple[str, str]: "state": state, "code_challenge": pkce_params.code_challenge, "code_challenge_method": "S256", - "resource": self.context.get_resource_url(), # RFC 8707 } + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 + if self.context.client_metadata.scope: auth_params["scope"] = self.context.client_metadata.scope @@ -358,9 +381,12 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), "client_id": self.context.client_info.client_id, "code_verifier": code_verifier, - "resource": self.context.get_resource_url(), # RFC 8707 } + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() # RFC 8707 + if self.context.client_info.client_secret: token_data["client_secret"] = self.context.client_info.client_secret @@ -409,9 +435,12 @@ async def _refresh_token(self) -> httpx.Request: "grant_type": "refresh_token", "refresh_token": self.context.current_tokens.refresh_token, "client_id": self.context.client_info.client_id, - "resource": self.context.get_resource_url(), # RFC 8707 } + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 + if self.context.client_info.client_secret: refresh_data["client_secret"] = self.context.client_info.client_secret @@ -457,6 +486,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if not self._initialized: await self._initialize() + # Capture protocol version from request headers + self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) + # Perform OAuth flow if not authenticated if not self.context.is_token_valid(): try: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index d87410d00..8e6b4f54d 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -13,6 +13,7 @@ OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, + ProtectedResourceMetadata, ) @@ -434,6 +435,87 @@ async def test_refresh_token_request(self, oauth_provider, valid_tokens): assert "client_secret=test_secret" in content +class TestProtectedResourceMetadata: + """Test protected resource handling.""" + + @pytest.mark.anyio + async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): + """Test resource parameter is included for protocol version >= 2025-06-18.""" + # Set protocol version to 2025-06-18 + oauth_provider.context.protocol_version = "2025-06-18" + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Test in token exchange + request = await oauth_provider._exchange_token("test_code", "test_verifier") + content = request.content.decode() + assert "resource=" in content + # Check URL-encoded resource parameter + from urllib.parse import quote + + expected_resource = quote(oauth_provider.context.get_resource_url(), safe="") + assert f"resource={expected_resource}" in content + + # Test in refresh token + oauth_provider.context.current_tokens = OAuthToken( + access_token="test_access", + token_type="Bearer", + refresh_token="test_refresh", + ) + refresh_request = await oauth_provider._refresh_token() + refresh_content = refresh_request.content.decode() + assert "resource=" in refresh_content + + @pytest.mark.anyio + async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): + """Test resource parameter is excluded for protocol version < 2025-06-18.""" + # Set protocol version to older version + oauth_provider.context.protocol_version = "2025-03-26" + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Test in token exchange + request = await oauth_provider._exchange_token("test_code", "test_verifier") + content = request.content.decode() + assert "resource=" not in content + + # Test in refresh token + oauth_provider.context.current_tokens = OAuthToken( + access_token="test_access", + token_type="Bearer", + refresh_token="test_refresh", + ) + refresh_request = await oauth_provider._refresh_token() + refresh_content = refresh_request.content.decode() + assert "resource=" not in refresh_content + + @pytest.mark.anyio + async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): + """Test resource parameter is always included when protected resource metadata exists.""" + # Set old protocol version but with protected resource metadata + oauth_provider.context.protocol_version = "2025-03-26" + oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), + authorization_servers=[AnyHttpUrl("https://api.example.com")], + ) + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Test in token exchange + request = await oauth_provider._exchange_token("test_code", "test_verifier") + content = request.content.decode() + assert "resource=" in content + + class TestAuthFlow: """Test the auth flow in httpx."""