From 1962f55e7b911646f234a23307395810b8ccf394 Mon Sep 17 00:00:00 2001 From: Guoxun Wei Date: Wed, 27 Aug 2025 10:06:46 +0800 Subject: [PATCH 1/4] support oauth --- docs/oauth-authentication.md | 454 ++++++++++++ internal/auth/oauth/endpoints.go | 915 +++++++++++++++++++++++++ internal/auth/oauth/endpoints_test.go | 436 ++++++++++++ internal/auth/oauth/middleware.go | 285 ++++++++ internal/auth/oauth/middleware_test.go | 251 +++++++ internal/auth/oauth/provider.go | 510 ++++++++++++++ internal/auth/oauth/provider_test.go | 381 ++++++++++ internal/auth/types.go | 130 ++++ internal/auth/types_test.go | 175 +++++ internal/config/config.go | 42 ++ internal/config/config_test.go | 27 + internal/config/validator.go | 12 +- internal/server/server.go | 128 +++- 13 files changed, 3741 insertions(+), 5 deletions(-) create mode 100644 docs/oauth-authentication.md create mode 100644 internal/auth/oauth/endpoints.go create mode 100644 internal/auth/oauth/endpoints_test.go create mode 100644 internal/auth/oauth/middleware.go create mode 100644 internal/auth/oauth/middleware_test.go create mode 100644 internal/auth/oauth/provider.go create mode 100644 internal/auth/oauth/provider_test.go create mode 100644 internal/auth/types.go create mode 100644 internal/auth/types_test.go create mode 100644 internal/config/config_test.go diff --git a/docs/oauth-authentication.md b/docs/oauth-authentication.md new file mode 100644 index 0000000..1588656 --- /dev/null +++ b/docs/oauth-authentication.md @@ -0,0 +1,454 @@ +# OAuth Authentication for AKS-MCP + +This document describes how to configure and use OAuth authentication with AKS-MCP. + +## Overview + +AKS-MCP now supports OAuth 2.1 authentication using Azure Active Directory as the authorization server. When enabled, OAuth authentication provides secure access control for MCP endpoints using Bearer tokens. + +## Features + +- **Azure AD Integration**: Uses Azure Active Directory as the OAuth authorization server +- **JWT Token Validation**: Validates JWT tokens with Azure AD signing keys +- **OAuth 2.0 Metadata Endpoints**: Provides standard OAuth metadata discovery endpoints +- **Dynamic Client Registration**: Supports RFC 7591 dynamic client registration +- **Token Introspection**: Implements RFC 7662 token introspection +- **Transport Support**: Works with both SSE and HTTP Streamable transports +- **Flexible Configuration**: Supports environment variables and command-line configuration + +## Environment Setup and Azure AD Configuration + +### Prerequisites + +Before setting up OAuth authentication, ensure you have: + +- Azure CLI installed and configured (`az login`) +- An Azure subscription with appropriate permissions to create applications +- Azure Active Directory tenant access + +### Important: Environment Variables Shared Between OAuth and Azure CLI + +AKS-MCP uses the same environment variables (`AZURE_TENANT_ID`, `AZURE_CLIENT_ID`) for both OAuth authentication and Azure CLI operations. This design provides configuration simplicity but requires careful permission setup: + +**When `AZURE_CLIENT_ID` is set:** +- OAuth: Used for validating user tokens accessing the MCP server +- Azure CLI: Used for managed identity/workload identity authentication to access Azure resources + +**Permission Requirements:** +- The Azure AD application must have both **OAuth permissions** (for user authentication) AND **Azure resource permissions** (for az CLI operations) +- Missing either set of permissions will cause authentication failures + +### Step 1: Create Azure AD Application + +#### Using Azure Portal (Recommended) + +1. **Navigate to Azure Portal** + - Go to https://portal.azure.com + - Sign in with your Azure account + +2. **Create App Registration** + ``` + Navigation: Azure Active Directory → App registrations → New registration + ``` + + Configure the following: + - **Name**: `AKS-MCP-OAuth` (or your preferred name) + - **Supported account types**: "Accounts in this organizational directory only" + - **Redirect URI Platform Options**: + +#### Supported Platform Types + +**✅ Mobile and desktop applications (Recommended)** +- **Platform**: "Mobile and desktop applications" +- **Redirect URIs**: + - `http://localhost:8000/oauth/callback` +- **Benefits**: + - Native support for PKCE (required by OAuth 2.1) + - No client secret required (public client) + - Better security for localhost redirects +- **Status**: ✅ **Confirmed working** + +**❌ Single-page application (SPA) - Not Recommended** +- **Platform**: "Single-page application (SPA)" +- **Redirect URIs**: Same as above +- **Benefits**: + - Designed for PKCE flow + - No client secret required +- **Critical Limitations**: + - **Token exchange restriction**: Azure AD error AADSTS9002327 - "Tokens issued for the 'Single-Page Application' client-type may only be redeemed via cross-origin requests" + - **Architecture mismatch**: SPA platform expects frontend JavaScript to handle token exchange, but AKS-MCP performs backend token exchange + - **CORS requirements**: Requires complex frontend-backend coordination for OAuth flow +- **Status**: ❌ **Not compatible with AKS-MCP's backend OAuth implementation** + +**❌ Web application** +- **Platform**: "Web" +- **Why not supported**: + - Requires client secret (confidential client) + - AKS-MCP implements public client flow without secrets + - PKCE handling may differ + +**Choose Platform Recommendation:** +1. **Primary**: Use "Mobile and desktop applications" (✅ confirmed working) +2. **Avoid**: "Single-page application" - incompatible with backend OAuth implementation (AADSTS9002327 error) +3. **Avoid**: "Web" platform due to client secret requirements + +3. **Record Essential Information** + From the "Overview" page, note: + - **Application (client) ID** - This is your `CLIENT_ID` + - **Directory (tenant) ID** - This is your `TENANT_ID` + +#### Using Azure CLI (Alternative) + +**For Mobile and desktop applications platform:** +```bash +# Create Azure AD application with public client platform +az ad app create --display-name "AKS-MCP-OAuth" \ + --public-client-redirect-uris "http://localhost:8000/oauth/callback" "http://localhost:6274/oauth/callback/debug" + +# Get application details +az ad app list --display-name "AKS-MCP-OAuth" --query "[0].{appId:appId,objectId:objectId}" + +# Get your tenant ID +az account show --query "tenantId" -o tsv +``` + +**For Single-page application platform:** +```bash +# Create Azure AD application with SPA platform +az ad app create --display-name "AKS-MCP-OAuth" \ + --spa-redirect-uris "http://localhost:8000/oauth/callback" "http://localhost:6274/oauth/callback/debug" + +# Get application details +az ad app list --display-name "AKS-MCP-OAuth" --query "[0].{appId:appId,objectId:objectId}" + +# Get your tenant ID +az account show --query "tenantId" -o tsv +``` + +### Step 2: Configure API Permissions + +**Critical: Both OAuth and Azure CLI require proper permissions** + +1. **Add Required API Permissions** + ``` + Navigation: Azure Active Directory → App registrations → [Your App] → API permissions + ``` + +2. **Add Azure Service Management Permission (Required for OAuth)** + - Click "Add a permission" + - Select "Microsoft APIs" → "Azure Service Management" + - Choose "Delegated permissions" + - Select `user_impersonation` + - Click "Add permissions" + +3. **Add Azure Resource Management Permissions (Required for Azure CLI)** + + When `AZURE_CLIENT_ID` is set, Azure CLI will use this application for authentication. Add these permissions based on your AKS-MCP access level: + + **For readonly access:** + - Microsoft Graph → Application permissions → `Directory.Read.All` + - Azure Service Management → Delegated permissions → `user_impersonation` + + **For readwrite/admin access:** + - Microsoft Graph → Application permissions → `Directory.Read.All` + - Azure Service Management → Delegated permissions → `user_impersonation` + - Consider adding specific Azure resource permissions based on your needs + +4. **Grant Admin Consent (Required)** + - Click "Grant admin consent for [Your Organization]" + - Confirm the consent + +**⚠️ Important Notes:** +- Without proper Azure CLI permissions, you'll see "Insufficient privileges" errors when AKS-MCP tries to access Azure resources +- The same application serves both OAuth authentication (user access to MCP) and Azure CLI authentication (MCP access to Azure) +- Test both OAuth flow AND Azure resource access after permission changes + +### Step 3: Environment Configuration + +Set the required environment variables: + +```bash +# Replace with your actual values from Step 1 +export AZURE_TENANT_ID="your-tenant-id" +export AZURE_CLIENT_ID="your-client-id" +export AZURE_SUBSCRIPTION_ID="your-subscription-id" # Optional, for AKS operations +``` + +**⚠️ Important: Dual Authentication Impact** + +When you set `AZURE_CLIENT_ID`, it affects both OAuth and Azure CLI authentication: + +1. **OAuth Authentication**: Validates user tokens for MCP server access +2. **Azure CLI Authentication**: AKS-MCP uses this client ID for managed identity authentication when accessing Azure resources + +**Common Issues:** +- If you only configured OAuth permissions, Azure CLI operations will fail with "Insufficient privileges" +- If you only configured Azure resource permissions, OAuth token validation may fail +- Solution: Ensure your Azure AD application has BOTH sets of permissions (see Step 2) + +**Testing Both Authentication Paths:** +```bash +# Test OAuth (should work after proper setup) +curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp + +# Test Azure CLI access (should work after proper permissions) +# This happens automatically when AKS-MCP tries to access Azure resources +./aks-mcp --oauth-enabled --access-level=readonly +``` + +**Note for MCP Inspector**: MCP Inspector requires the callback URL `http://localhost:6274/oauth/callback/debug` to be configured in your Azure AD application redirect URIs for proper OAuth testing. + +### Step 4: Start AKS-MCP with OAuth + +```bash +# Using HTTP Streamable transport with OAuth (recommended) +./aks-mcp \ + --transport=streamable-http \ + --port=8000 \ + --oauth-enabled \ + --oauth-tenant-id="$AZURE_TENANT_ID" \ + --oauth-client-id="$AZURE_CLIENT_ID" \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback/debug" \ + --access-level=readonly + +# Using SSE transport with OAuth (alternative) +./aks-mcp \ + --transport=sse \ + --port=8000 \ + --oauth-enabled \ + --oauth-tenant-id="$AZURE_TENANT_ID" \ + --oauth-client-id="$AZURE_CLIENT_ID" \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback/debug" \ + --access-level=readonly + +# Environment variables are automatically used if set +# You can also just use: +./aks-mcp --transport=streamable-http --port=8000 --oauth-enabled --access-level=readonly +``` + +## Configuration Options + +### Command Line Flags + +- `--oauth-enabled`: Enable OAuth authentication (default: false) +- `--oauth-tenant-id`: Azure AD tenant ID (or use AZURE_TENANT_ID env var) +- `--oauth-client-id`: Azure AD client ID (or use AZURE_CLIENT_ID env var) +- `--oauth-redirects`: Comma-separated list of allowed redirect URIs (required when OAuth enabled) + +**Note**: OAuth scopes are automatically configured to use `https://management.azure.com/.default` for optimal Azure AD compatibility. Custom scopes are not currently configurable via command line. + +### Example with Command Line Flags + +```bash +./aks-mcp --transport=sse --oauth-enabled=true \ + --oauth-tenant-id="12345678-1234-1234-1234-123456789012" \ + --oauth-client-id="87654321-4321-4321-4321-210987654321" \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback/debug" +``` + +**Note**: Scopes are automatically set to `https://management.azure.com/.default` and cannot be customized via command line. + +## OAuth Endpoints + +When OAuth is enabled, the following endpoints are available: + +### Metadata Endpoints (Unauthenticated) + +- `GET /.well-known/oauth-protected-resource` - OAuth 2.0 Protected Resource Metadata (RFC 9728) +- `GET /.well-known/oauth-authorization-server` - OAuth 2.0 Authorization Server Metadata (RFC 8414) +- `GET /.well-known/openid-configuration` - OpenID Connect Discovery (alias for authorization server metadata) +- `GET /health` - Health check endpoint + +### OAuth Flow Endpoints (Unauthenticated) + +- `GET /oauth2/v2.0/authorize` - Authorization endpoint proxy to Azure AD +- `POST /oauth2/v2.0/token` - Token exchange endpoint proxy to Azure AD +- `GET /oauth/callback` - Authorization Code flow callback handler +- `POST /oauth/register` - Dynamic Client Registration (RFC 7591) + +### Token Management (Unauthenticated for simplicity) + +- `POST /oauth/introspect` - Token Introspection (RFC 7662) + +### Authenticated MCP Endpoints + +When OAuth is enabled, these endpoints require Bearer token authentication: + +- **SSE Transport**: `GET /sse`, `POST /message` +- **HTTP Streamable Transport**: `POST /mcp` + +## Client Integration + +### Obtaining an Access Token + +Use the Azure AD OAuth flow to obtain an access token: + +```bash +# Example using Azure CLI (for testing) +az account get-access-token --resource https://management.azure.com/ --query accessToken -o tsv +``` + +### Making Authenticated Requests + +Include the Bearer token in the Authorization header: + +```bash +# Example authenticated request to SSE endpoint +curl -H "Authorization: Bearer YOUR_ACCESS_TOKEN" \ + -H "Accept: text/event-stream" \ + http://localhost:8000/sse + +# Example authenticated request to HTTP Streamable endpoint +curl -H "Authorization: Bearer YOUR_ACCESS_TOKEN" \ + -H "Content-Type: application/json" \ + -X POST http://localhost:8000/mcp \ + -d '{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}' +``` + +## Testing OAuth Integration + +### 1. Test OAuth Metadata + +```bash +# Get protected resource metadata +curl http://localhost:8000/.well-known/oauth-protected-resource + +# Get authorization server metadata +curl http://localhost:8000/.well-known/oauth-authorization-server +``` + +### 2. Test Dynamic Client Registration + +```bash +curl -X POST http://localhost:8000/oauth/register \ + -H "Content-Type: application/json" \ + -d '{ + "redirect_uris": ["http://localhost:3000/oauth/callback"], + "client_name": "Test MCP Client", + "grant_types": ["authorization_code"] + }' +``` + +### 3. Test Token Introspection + +```bash +curl -X POST http://localhost:8000/oauth/introspect \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "token=YOUR_ACCESS_TOKEN" +``` + +## Security Considerations + +1. **HTTPS in Production**: Always use HTTPS in production environments +2. **Token Validation**: JWT tokens are validated against Azure AD signing keys +3. **Scope Validation**: Tokens must include required scopes +4. **Audience Validation**: Tokens must have the correct audience claim +5. **Redirect URI Validation**: Only configured redirect URIs are allowed + +## Troubleshooting + +### Common Issues + +#### Authentication and Token Issues +1. **Invalid Token**: Ensure the token is valid and not expired +2. **Wrong Audience**: Verify the token audience matches `https://management.azure.com` +3. **Missing Scopes**: Ensure the token includes `https://management.azure.com/.default` scope +4. **JWT Signature Validation Failed**: + - Check that Azure AD application platform is set correctly + - Verify tenant ID matches the issuer in the token + - Ensure token is using v2.0 format (from Azure Management API scope) + +#### Azure AD Application Configuration Issues +5. **Client ID Not Found**: Verify the Application (client) ID is correct +6. **Redirect URI Mismatch**: Ensure redirect URIs match exactly in Azure AD app registration +7. **Wrong Platform Type**: Use "Mobile and desktop applications", NOT "Web" or "Single-page application" +8. **Insufficient Permissions**: Verify both OAuth and Azure resource permissions are configured +9. **SPA Platform Incompatibility (AADSTS9002327)**: + - Error: "Tokens issued for the 'Single-Page Application' client-type may only be redeemed via cross-origin requests" + - Solution: Change Azure AD app platform to "Mobile and desktop applications" + - Cause: SPA platform requires frontend token exchange, incompatible with AKS-MCP's backend implementation + +#### Network and Endpoint Issues +10. **CORS Errors**: Check that redirect URIs are properly configured for localhost +11. **Network Issues**: Check connectivity to Azure AD endpoints +11. **Port Conflicts**: Ensure the configured port (default 8000) is available + +#### Scope and Permission Issues +12. **Scope Mixing Error**: + - Error: "scope can't be combined with resource-specific scopes" + - Solution: Our implementation automatically handles this by using only Azure Management API scope +13. **Resource Parameter Issues**: + - Azure AD doesn't support RFC 8707 resource parameter + - Our implementation works around this limitation automatically + +### Debug Logging + +Enable verbose logging for OAuth debugging: + +```bash +./aks-mcp --oauth-enabled=true --verbose +``` + +### Health Check + +Use the health endpoint to verify OAuth configuration: + +```bash +curl http://localhost:8000/health +``` + +Expected response with OAuth enabled: +```json +{ + "status": "healthy", + "oauth": { + "enabled": true + } +} +``` + +### Testing OAuth Flow Step by Step + +1. **Test Metadata Discovery**: +```bash +# Should return authorization server URLs +curl http://localhost:8000/.well-known/oauth-protected-resource + +# Should return PKCE support and endpoints +curl http://localhost:8000/.well-known/oauth-authorization-server +``` + +2. **Test Client Registration**: +```bash +curl -X POST http://localhost:8000/oauth/register \ + -H "Content-Type: application/json" \ + -d '{ + "redirect_uris": ["http://localhost:8000/oauth/callback"], + "client_name": "Test Client" + }' +``` + +3. **Test Authorization Flow**: + - Open browser to: `http://localhost:8000/oauth2/v2.0/authorize?response_type=code&client_id=YOUR_CLIENT_ID&redirect_uri=http://localhost:8000/oauth/callback&scope=https://management.azure.com/.default&code_challenge=CHALLENGE&code_challenge_method=S256&state=STATE` + +4. **Verify Token Validation**: +```bash +# Use a valid Azure AD token +curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp +``` + +## Migration from Non-OAuth + +To migrate from a non-OAuth AKS-MCP deployment: + +1. Update clients to obtain and include Bearer tokens +2. Enable OAuth on the server with `--oauth-enabled=true` +3. Configure Azure AD application and credentials +4. Test with a subset of clients before full migration +5. Monitor logs for authentication errors + +## Integration with MCP Inspector + +The MCP Inspector tool can be used to test OAuth-enabled AKS-MCP servers. Configure the Inspector's OAuth settings to match your AKS-MCP OAuth configuration for testing. + +For more information, see the MCP OAuth specification and Azure AD documentation. \ No newline at end of file diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go new file mode 100644 index 0000000..c4217b9 --- /dev/null +++ b/internal/auth/oauth/endpoints.go @@ -0,0 +1,915 @@ +package oauth + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Azure/aks-mcp/internal/auth" + "github.com/Azure/aks-mcp/internal/config" +) + +// EndpointManager manages OAuth-related HTTP endpoints +type EndpointManager struct { + provider *AzureOAuthProvider + config *auth.OAuthConfig +} + +// NewEndpointManager creates a new OAuth endpoint manager +func NewEndpointManager(provider *AzureOAuthProvider, config *auth.OAuthConfig) *EndpointManager { + return &EndpointManager{ + provider: provider, + config: config, + } +} + +// setCORSHeaders sets CORS headers for OAuth endpoints to allow MCP Inspector access +func (em *EndpointManager) setCORSHeaders(w http.ResponseWriter) { + origin := "*" // TODO: Restrict to specific origins + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + w.Header().Set("Access-Control-Allow-Credentials", "false") // Explicit false for wildcard origin +} + +// setCacheHeaders sets cache control headers based on EnableCache configuration +func (em *EndpointManager) setCacheHeaders(w http.ResponseWriter) { + if config.EnableCache { + // Enable caching for 1 hour when cache is enabled + w.Header().Set("Cache-Control", "max-age=3600") + } else { + // Disable all caching when cache is disabled (for debugging) + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Expires", "0") + } +} + +// RegisterEndpoints registers OAuth endpoints with the provided HTTP mux +func (em *EndpointManager) RegisterEndpoints(mux *http.ServeMux) { + // OAuth 2.0 Protected Resource Metadata endpoint (RFC 9728) + mux.HandleFunc("/.well-known/oauth-protected-resource", em.protectedResourceMetadataHandler()) + + // OAuth 2.0 Authorization Server Metadata endpoint (RFC 8414) + // Note: This would typically be served by Azure AD, but we provide a proxy for convenience + mux.HandleFunc("/.well-known/oauth-authorization-server", em.authServerMetadataProxyHandler()) + + // OpenID Connect Discovery endpoint (compatibility with MCP Inspector) + mux.HandleFunc("/.well-known/openid-configuration", em.authServerMetadataProxyHandler()) + + // Authorization endpoint proxy to handle Azure AD compatibility + mux.HandleFunc("/oauth2/v2.0/authorize", em.authorizationProxyHandler()) + + // Dynamic Client Registration endpoint (RFC 7591) + mux.HandleFunc("/oauth/register", em.clientRegistrationHandler()) + + // Token introspection endpoint (RFC 7662) - optional + mux.HandleFunc("/oauth/introspect", em.tokenIntrospectionHandler()) + + // OAuth 2.0 callback endpoint for Authorization Code flow + mux.HandleFunc("/oauth/callback", em.callbackHandler()) + + // OAuth 2.0 token endpoint for Authorization Code exchange + mux.HandleFunc("/oauth2/v2.0/token", em.tokenHandler()) + + // Health check endpoint (unauthenticated) + mux.HandleFunc("/health", em.healthHandler()) +} + +// authServerMetadataProxyHandler proxies authorization server metadata from Azure AD +func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received request for authorization server metadata: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for metadata endpoint", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Get metadata from Azure AD + provider := em.provider + + // Build server URL based on the request + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + // Use the Host header from the request + host := r.Host + if host == "" { + host = r.URL.Host + } + + serverURL := fmt.Sprintf("%s://%s", scheme, host) + + metadata, err := provider.GetAuthorizationServerMetadata(serverURL) + if err != nil { + log.Printf("Failed to fetch authorization server metadata: %v\n", err) + http.Error(w, fmt.Sprintf("Failed to fetch authorization server metadata: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + em.setCacheHeaders(w) + + if err := json.NewEncoder(w).Encode(metadata); err != nil { + log.Printf("Failed to encode response: %v\n", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// clientRegistrationHandler implements OAuth 2.0 Dynamic Client Registration (RFC 7591) +func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received client registration request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodPost { + log.Printf("OAuth ERROR: Invalid method %s for client registration endpoint, only POST allowed", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse client registration request + var registrationRequest ClientRegistrationRequest + + if err := json.NewDecoder(r.Body).Decode(®istrationRequest); err != nil { + log.Printf("OAuth ERROR: Failed to parse client registration JSON: %v", err) + em.writeErrorResponse(w, "invalid_request", "Invalid JSON in request body", http.StatusBadRequest) + return + } + + log.Printf("OAuth DEBUG: Client registration request parsed - client_name: %s, redirect_uris: %v", registrationRequest.ClientName, registrationRequest.RedirectURIs) + + // Validate registration request + if err := em.validateClientRegistration(®istrationRequest); err != nil { + log.Printf("OAuth ERROR: Client registration validation failed: %v", err) + em.writeErrorResponse(w, "invalid_client_metadata", err.Error(), http.StatusBadRequest) + return + } + + // Use client-requested grant types if provided and valid, otherwise use defaults + grantTypes := registrationRequest.GrantTypes + if len(grantTypes) == 0 { + grantTypes = []string{"authorization_code", "refresh_token"} + } + + // Use client-requested response types if provided and valid, otherwise use defaults + responseTypes := registrationRequest.ResponseTypes + if len(responseTypes) == 0 { + responseTypes = []string{"code"} + } + + // For Azure AD compatibility, use the configured client ID + // In a full RFC 7591 implementation, each registration would get a unique ID + // But since Azure AD requires pre-registered client IDs, we return the configured one + clientID := em.config.ClientID + + log.Printf("OAuth DEBUG: Client registration successful - returning client_id: %s", clientID) + + clientInfo := map[string]interface{}{ + "client_id": clientID, // Use configured Azure AD client ID + "client_id_issued_at": time.Now().Unix(), // RFC 7591: timestamp of issuance + "redirect_uris": registrationRequest.RedirectURIs, + "token_endpoint_auth_method": "none", // Public client (PKCE required) + "grant_types": grantTypes, + "response_types": responseTypes, + "client_name": registrationRequest.ClientName, + "client_uri": registrationRequest.ClientURI, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + + if err := json.NewEncoder(w).Encode(clientInfo); err != nil { + log.Printf("OAuth ERROR: Failed to encode client registration response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// validateClientRegistration validates a client registration request +func (em *EndpointManager) validateClientRegistration(req *ClientRegistrationRequest) error { + // Validate redirect URIs - require at least one + if len(req.RedirectURIs) == 0 { + return fmt.Errorf("at least one redirect_uri is required") + } + + // Basic URL validation for redirect URIs + for _, redirectURI := range req.RedirectURIs { + if _, err := url.Parse(redirectURI); err != nil { + return fmt.Errorf("invalid redirect_uri format: %s", redirectURI) + } + } + + // Validate grant types + validGrantTypes := map[string]bool{ + "authorization_code": true, + "refresh_token": true, + } + + for _, grantType := range req.GrantTypes { + if !validGrantTypes[grantType] { + return fmt.Errorf("unsupported grant_type: %s", grantType) + } + } + + // Validate response types + validResponseTypes := map[string]bool{ + "code": true, + } + + for _, responseType := range req.ResponseTypes { + if !validResponseTypes[responseType] { + return fmt.Errorf("unsupported response_type: %s", responseType) + } + } + + return nil +} + +// tokenIntrospectionHandler implements RFC 7662 OAuth 2.0 Token Introspection +func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // This endpoint should be protected with client authentication + // For simplicity, we'll skip client auth in this implementation + + token := r.FormValue("token") + if token == "" { + em.writeErrorResponse(w, "invalid_request", "Missing token parameter", http.StatusBadRequest) + return + } + + // Validate the token + provider := em.provider + + tokenInfo, err := provider.ValidateToken(r.Context(), token) + if err != nil { + // Return inactive token response + response := map[string]interface{}{ + "active": false, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + return + } + + // Return active token response + response := map[string]interface{}{ + "active": true, + "client_id": em.config.ClientID, + "scope": strings.Join(tokenInfo.Scope, " "), + "sub": tokenInfo.Subject, + "aud": tokenInfo.Audience, + "iss": tokenInfo.Issuer, + "exp": tokenInfo.ExpiresAt.Unix(), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// healthHandler provides a simple health check endpoint +func (em *EndpointManager) healthHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + response := map[string]interface{}{ + "status": "healthy", + "oauth": map[string]interface{}{ + "enabled": em.config.Enabled, + }, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// protectedResourceMetadataHandler handles OAuth 2.0 Protected Resource Metadata requests +func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received request for protected resource metadata: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for protected resource metadata endpoint", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Build resource URL based on the request + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + // Use the Host header from the request + host := r.Host + if host == "" { + host = r.URL.Host + } + + // Build the resource URL + resourceURL := fmt.Sprintf("%s://%s", scheme, host) + log.Printf("OAuth DEBUG: Building protected resource metadata for URL: %s", resourceURL) + + provider := em.provider + + metadata, err := provider.GetProtectedResourceMetadata(resourceURL) + if err != nil { + log.Printf("OAuth ERROR: Failed to get protected resource metadata: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + log.Printf("OAuth DEBUG: Successfully generated protected resource metadata with %d authorization servers", len(metadata.AuthorizationServers)) + + w.Header().Set("Content-Type", "application/json") + em.setCacheHeaders(w) + + if err := json.NewEncoder(w).Encode(metadata); err != nil { + log.Printf("OAuth ERROR: Failed to encode protected resource metadata response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// writeErrorResponse writes an OAuth error response +func (em *EndpointManager) writeErrorResponse(w http.ResponseWriter, errorCode, description string, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + response := map[string]interface{}{ + "error": errorCode, + "error_description": description, + } + + json.NewEncoder(w).Encode(response) +} + +// authorizationProxyHandler proxies authorization requests to Azure AD with resource parameter filtering +func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received authorization proxy request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for authorization endpoint, only GET allowed", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse query parameters + query := r.URL.Query() + + // Enforce PKCE for OAuth 2.1 compliance (MCP requirement) + codeChallenge := query.Get("code_challenge") + codeChallengeMethod := query.Get("code_challenge_method") + + if codeChallenge == "" { + log.Printf("OAuth ERROR: Missing PKCE code_challenge parameter (required for OAuth 2.1)") + em.writeErrorResponse(w, "invalid_request", "PKCE code_challenge is required", http.StatusBadRequest) + return + } + + if codeChallengeMethod == "" { + // Default to S256 if not specified + query.Set("code_challenge_method", "S256") + log.Printf("OAuth DEBUG: Setting default code_challenge_method to S256") + } else if codeChallengeMethod != "S256" { + log.Printf("OAuth ERROR: Unsupported code_challenge_method: %s (only S256 supported)", codeChallengeMethod) + em.writeErrorResponse(w, "invalid_request", "Only S256 code_challenge_method is supported", http.StatusBadRequest) + return + } + + // Resource parameter handling for MCP compliance + // requestedScopes := strings.Split(query.Get("scope"), " ") + + // Azure AD v2.0 doesn't support RFC 8707 Resource Indicators in authorization requests + // Remove the resource parameter if present for Azure AD compatibility + resourceParam := query.Get("resource") + if resourceParam != "" { + log.Printf("OAuth DEBUG: Removing resource parameter for Azure AD compatibility: %s", resourceParam) + query.Del("resource") + } + + // Use only server-required scopes for Azure AD compatibility + // Azure AD .default scopes cannot be mixed with OpenID Connect scopes + // We prioritize Azure Management API access over OpenID Connect user info + finalScopes := em.config.RequiredScopes + + finalScopeString := strings.Join(finalScopes, " ") + query.Set("scope", finalScopeString) + log.Printf("OAuth DEBUG: Setting final scope for Azure AD: %s", finalScopeString) + + // Build the Azure AD authorization URL + azureAuthURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", em.config.TenantID) + + // Create the redirect URL with filtered parameters + redirectURL := fmt.Sprintf("%s?%s", azureAuthURL, query.Encode()) + log.Printf("OAuth DEBUG: Redirecting to Azure AD authorization endpoint: %s", azureAuthURL) + + // Redirect to Azure AD + http.Redirect(w, r, redirectURL, http.StatusFound) + } +} + +// callbackHandler handles OAuth 2.0 Authorization Code flow callback +func (em *EndpointManager) callbackHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received callback request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + log.Printf("OAuth ERROR: Invalid method %s for callback endpoint, only GET allowed", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse query parameters + query := r.URL.Query() + + // Check for error response from authorization server + if authError := query.Get("error"); authError != "" { + errorDesc := query.Get("error_description") + log.Printf("OAuth ERROR: Authorization server returned error: %s - %s", authError, errorDesc) + em.writeCallbackErrorResponse(w, fmt.Sprintf("Authorization failed: %s - %s", authError, errorDesc)) + return + } + + // Get authorization code + code := query.Get("code") + if code == "" { + log.Printf("OAuth ERROR: Missing authorization code in callback") + em.writeCallbackErrorResponse(w, "Missing authorization code") + return + } + + // Get state parameter for CSRF protection + state := query.Get("state") + if state == "" { + log.Printf("OAuth ERROR: Missing state parameter in callback") + em.writeCallbackErrorResponse(w, "Missing state parameter") + return + } + + log.Printf("OAuth DEBUG: Callback parameters validated - has_code: true, state: %s", state) + + // Exchange authorization code for access token + tokenResponse, err := em.exchangeCodeForToken(code, state) + if err != nil { + log.Printf("OAuth ERROR: Failed to exchange authorization code for token: %v", err) + em.writeCallbackErrorResponse(w, fmt.Sprintf("Failed to exchange code for token: %v", err)) + return + } + + log.Printf("OAuth DEBUG: Token exchange successful, preparing callback success response") + + // Skip token validation in callback - validation happens during MCP requests + // Create minimal token info for callback success page + tokenInfo := &auth.TokenInfo{ + AccessToken: tokenResponse.AccessToken, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration + Scope: em.config.RequiredScopes, // Use configured scopes + Subject: "authenticated_user", // Placeholder + Audience: []string{fmt.Sprintf("https://sts.windows.net/%s/", em.config.TenantID)}, + Issuer: fmt.Sprintf("https://sts.windows.net/%s/", em.config.TenantID), + Claims: make(map[string]interface{}), + } + + // Return success response with token information + em.writeCallbackSuccessResponse(w, tokenResponse, tokenInfo) + } +} + +// TokenResponse represents the response from token exchange +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// exchangeCodeForToken exchanges authorization code for access token +func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenResponse, error) { + // Prepare token exchange request + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.config.TenantID) + + // Use default callback redirect URI for token exchange + redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", 8080) // Default for MCP servers + + // Prepare form data + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("client_id", em.config.ClientID) + data.Set("code", code) + data.Set("redirect_uri", redirectURI) + data.Set("scope", strings.Join(em.config.RequiredScopes, " ")) + + // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests + // It uses scope-based resource identification instead + // For MCP compliance, we handle resource binding through audience validation + + // Make token exchange request + resp, err := http.PostForm(tokenURL, data) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse token response + var tokenResponse TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &tokenResponse, nil +} + +// writeCallbackErrorResponse writes an error response for callback +func (em *EndpointManager) writeCallbackErrorResponse(w http.ResponseWriter, message string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + + html := fmt.Sprintf(` + + + + OAuth Authentication Error + + + +
+

Authentication Error

+

%s

+

Please try again or contact your administrator.

+
+ +`, message) + + w.Write([]byte(html)) +} + +// writeCallbackSuccessResponse writes a success response for callback +func (em *EndpointManager) writeCallbackSuccessResponse(w http.ResponseWriter, tokenResponse *TokenResponse, tokenInfo *auth.TokenInfo) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + + // Generate a secure session token for the client to use + _, err := em.generateSessionToken() + if err != nil { + em.writeCallbackErrorResponse(w, "Failed to generate session token") + return + } + + html := fmt.Sprintf(` + + + + OAuth Authentication Success + + + +
+

Authentication Successful

+

You have been successfully authenticated with Azure AD.

+ +
+

Access Token (use as Bearer token):

+
%s
+ +
+ +
+

Token Information:

+
    +
  • Subject: %s
  • +
  • Audience: %s
  • +
  • Scope: %s
  • +
  • Expires: %s
  • +
+
+ +
+

For MCP Client Usage:

+

Use this token in the Authorization header:

+
Authorization: Bearer %s
+ +
+
+ + + +`, + tokenResponse.AccessToken, + tokenInfo.Subject, + strings.Join(tokenInfo.Audience, ", "), + strings.Join(tokenInfo.Scope, ", "), + tokenInfo.ExpiresAt.Format("2006-01-02 15:04:05 UTC"), + tokenResponse.AccessToken, + tokenResponse.AccessToken) + + w.Write([]byte(html)) +} + +// isValidClientID validates if a client ID is acceptable +func (em *EndpointManager) isValidClientID(clientID string) bool { + // Accept configured client ID (primary method for Azure AD) + if clientID == em.config.ClientID { + return true + } + + // For future extensibility, could accept other registered client IDs + // But for Azure AD integration, we primarily use the configured client ID + + return false +} + +// generateSessionToken generates a secure random session token +func (em *EndpointManager) generateSessionToken() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(bytes), nil +} + +// tokenHandler handles OAuth 2.0 token endpoint requests (Authorization Code exchange) +func (em *EndpointManager) tokenHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log.Printf("OAuth DEBUG: Received token endpoint request: %s %s", r.Method, r.URL.Path) + + // Set CORS headers for all requests + em.setCORSHeaders(w) + + // Handle preflight OPTIONS request + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodPost { + log.Printf("OAuth ERROR: Invalid method %s for token endpoint, only POST allowed", r.Method) + em.writeErrorResponse(w, "invalid_request", "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + log.Printf("OAuth ERROR: Failed to parse form data: %v", err) + em.writeErrorResponse(w, "invalid_request", "Failed to parse form data", http.StatusBadRequest) + return + } + + // Validate grant type + grantType := r.FormValue("grant_type") + if grantType != "authorization_code" { + log.Printf("OAuth ERROR: Unsupported grant type: %s", grantType) + em.writeErrorResponse(w, "unsupported_grant_type", fmt.Sprintf("Unsupported grant type: %s", grantType), http.StatusBadRequest) + return + } + + // Extract required parameters + code := r.FormValue("code") + clientID := r.FormValue("client_id") + redirectURI := r.FormValue("redirect_uri") + codeVerifier := r.FormValue("code_verifier") // PKCE parameter + + if code == "" { + log.Printf("OAuth ERROR: Missing authorization code in token request") + em.writeErrorResponse(w, "invalid_request", "Missing authorization code", http.StatusBadRequest) + return + } + + if clientID == "" { + log.Printf("OAuth ERROR: Missing client_id in token request") + em.writeErrorResponse(w, "invalid_request", "Missing client_id", http.StatusBadRequest) + return + } + + if redirectURI == "" { + log.Printf("OAuth ERROR: Missing redirect_uri in token request") + em.writeErrorResponse(w, "invalid_request", "Missing redirect_uri", http.StatusBadRequest) + return + } + + // Enforce PKCE code_verifier for OAuth 2.1 compliance + if codeVerifier == "" { + log.Printf("OAuth ERROR: Missing PKCE code_verifier (required for OAuth 2.1)") + em.writeErrorResponse(w, "invalid_request", "PKCE code_verifier is required", http.StatusBadRequest) + return + } + + log.Printf("OAuth DEBUG: Token request parameters validated - client_id: %s, redirect_uri: %s, has_code_verifier: %t", clientID, redirectURI, codeVerifier != "") + + // Validate client ID (accept both configured and dynamically registered clients) + if !em.isValidClientID(clientID) { + log.Printf("OAuth ERROR: Invalid client_id: %s", clientID) + em.writeErrorResponse(w, "invalid_client", "Invalid client_id", http.StatusBadRequest) + return + } + + // Extract scope from the token request (MCP client should send the same scope) + requestedScope := r.FormValue("scope") + if requestedScope == "" { + // Fallback to server required scopes if not provided + requestedScope = strings.Join(em.config.RequiredScopes, " ") + } + + log.Printf("OAuth DEBUG: Exchanging authorization code for access token with Azure AD, scope: %s", requestedScope) + + // Exchange authorization code for access token with Azure AD + tokenResponse, err := em.exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, requestedScope) + if err != nil { + log.Printf("OAuth ERROR: Token exchange with Azure AD failed: %v", err) + em.writeErrorResponse(w, "invalid_grant", fmt.Sprintf("Authorization code exchange failed: %v", err), http.StatusBadRequest) + return + } + + log.Printf("OAuth DEBUG: Token exchange successful, returning token response to client") + + // Return token response + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + if err := json.NewEncoder(w).Encode(tokenResponse); err != nil { + log.Printf("OAuth ERROR: Failed to encode token response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } +} + +// exchangeCodeForTokenDirect exchanges authorization code for access token directly with Azure AD +func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, scope string) (*TokenResponse, error) { + // Prepare token exchange request to Azure AD + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.config.TenantID) + + // Prepare form data + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("client_id", em.config.ClientID) + data.Set("code", code) + data.Set("redirect_uri", redirectURI) + data.Set("scope", scope) // Use the scope provided by the client + + // Add PKCE code_verifier if present + if codeVerifier != "" { + data.Set("code_verifier", codeVerifier) + log.Printf("Including PKCE code_verifier in Azure AD token request") + } else { + log.Printf("No PKCE code_verifier provided - this may cause PKCE verification to fail") + } + + // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests + // It uses scope-based resource identification instead + // For MCP compliance, we handle resource binding through audience validation + log.Printf("Azure AD token request with scope: %s", scope) + + // Make token exchange request to Azure AD + resp, err := http.PostForm(tokenURL, data) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse token response + var tokenResponse TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + log.Printf("Token exchange successful: access_token received (length: %d)", len(tokenResponse.AccessToken)) + + return &tokenResponse, nil +} diff --git a/internal/auth/oauth/endpoints_test.go b/internal/auth/oauth/endpoints_test.go new file mode 100644 index 0000000..c2aa1f8 --- /dev/null +++ b/internal/auth/oauth/endpoints_test.go @@ -0,0 +1,436 @@ +package oauth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Azure/aks-mcp/internal/auth" +) + +func TestEndpointManager_RegisterEndpoints(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + }, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + mux := http.NewServeMux() + manager.RegisterEndpoints(mux) + + // Test that endpoints are registered by making requests + testCases := []struct { + method string + path string + status int + }{ + {"GET", "/.well-known/oauth-protected-resource", http.StatusOK}, + {"GET", "/.well-known/oauth-authorization-server", http.StatusInternalServerError}, // Will fail without real Azure AD + {"POST", "/oauth/register", http.StatusBadRequest}, // Missing required data + {"POST", "/oauth/introspect", http.StatusBadRequest}, // Missing token param + {"GET", "/oauth/callback", http.StatusBadRequest}, // Missing required params + {"GET", "/health", http.StatusOK}, + } + + for _, tc := range testCases { + t.Run(tc.method+" "+tc.path, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.path, nil) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + if w.Code != tc.status { + t.Errorf("Expected status %d for %s %s, got %d", tc.status, tc.method, tc.path, w.Code) + } + }) + } +} + +func TestProtectedResourceMetadataEndpoint(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) + w := httptest.NewRecorder() + + handler := manager.protectedResourceMetadataHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var metadata ProtectedResourceMetadata + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + expectedAuthServer := "http://example.com" + if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServer { + t.Errorf("Expected auth server %s, got %v", expectedAuthServer, metadata.AuthorizationServers) + } + + if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" { + t.Errorf("Expected scopes %v, got %v", config.RequiredScopes, metadata.ScopesSupported) + } +} + +func TestClientRegistrationEndpoint(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test valid registration request + registrationRequest := map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "token_endpoint_auth_method": "none", + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + "scope": "https://management.azure.com/.default", + "client_name": "Test Client", + } + + reqBody, _ := json.Marshal(registrationRequest) + req := httptest.NewRequest("POST", "/oauth/register", strings.NewReader(string(reqBody))) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + handler := manager.clientRegistrationHandler() + handler(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d", w.Code) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if response["client_id"] == "" { + t.Error("Expected client_id in response") + } + + redirectURIs, ok := response["redirect_uris"].([]interface{}) + if !ok || len(redirectURIs) != 1 { + t.Errorf("Expected redirect URIs in response") + } +} + +func TestTokenIntrospectionEndpoint(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + }, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test with valid token (since JWT validation is disabled, any token works) + // Note: Must use a token that looks like a JWT (has dots) to pass initial format checks + req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader("token=header.payload.signature")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + w := httptest.NewRecorder() + handler := manager.tokenIntrospectionHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if active, ok := response["active"].(bool); !ok || !active { + t.Error("Expected active token") + } +} + +func TestTokenIntrospectionEndpointMissingToken(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test without token parameter + req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader("")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + w := httptest.NewRecorder() + handler := manager.tokenIntrospectionHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing token, got %d", w.Code) + } +} + +func TestHealthEndpoint(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + + handler := manager.healthHandler() + handler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + var response map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if response["status"] != "healthy" { + t.Errorf("Expected status healthy, got %v", response["status"]) + } + + oauth, ok := response["oauth"].(map[string]interface{}) + if !ok { + t.Error("Expected oauth object in response") + } + + if oauth["enabled"] != true { + t.Errorf("Expected oauth enabled true, got %v", oauth["enabled"]) + } +} + +func TestValidateClientRegistration(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + tests := []struct { + name string + request map[string]interface{} + wantErr bool + }{ + { + name: "valid request", + request: map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + }, + wantErr: false, + }, + { + name: "missing redirect URIs", + request: map[string]interface{}{ + "grant_types": []string{"authorization_code"}, + "response_types": []string{"code"}, + }, + wantErr: true, + }, + { + name: "invalid grant type", + request: map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "grant_types": []string{"client_credentials"}, + "response_types": []string{"code"}, + }, + wantErr: true, + }, + { + name: "invalid response type", + request: map[string]interface{}{ + "redirect_uris": []string{"http://localhost:3000/callback"}, + "grant_types": []string{"authorization_code"}, + "response_types": []string{"token"}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Convert test request to the expected struct format + req := &ClientRegistrationRequest{} + + if redirectURIs, ok := tt.request["redirect_uris"].([]string); ok { + req.RedirectURIs = redirectURIs + } + if grantTypes, ok := tt.request["grant_types"].([]string); ok { + req.GrantTypes = grantTypes + } + if responseTypes, ok := tt.request["response_types"].([]string); ok { + req.ResponseTypes = responseTypes + } + + err := manager.validateClientRegistration(req) + if (err != nil) != tt.wantErr { + t.Errorf("validateClientRegistration() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCallbackEndpointMissingCode(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test callback without authorization code + req := httptest.NewRequest("GET", "/oauth/callback?state=test-state", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing code, got %d", w.Code) + } + + // Check that response contains HTML error page + contentType := w.Header().Get("Content-Type") + if !strings.Contains(contentType, "text/html") { + t.Errorf("Expected HTML content type, got %s", contentType) + } + + body := w.Body.String() + if !strings.Contains(body, "Missing authorization code") { + t.Error("Expected error message about missing authorization code") + } +} + +func TestCallbackEndpointMissingState(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test callback without state parameter + req := httptest.NewRequest("GET", "/oauth/callback?code=test-code", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing state, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "Missing state parameter") { + t.Error("Expected error message about missing state parameter") + } +} + +func TestCallbackEndpointAuthError(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test callback with authorization error + req := httptest.NewRequest("GET", "/oauth/callback?error=access_denied&error_description=User%20denied%20access", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for auth error, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "Authorization failed") { + t.Error("Expected error message about authorization failure") + } + if !strings.Contains(body, "access_denied") { + t.Error("Expected specific error code in response") + } +} + +func TestCallbackEndpointMethodNotAllowed(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + } + + provider, _ := NewAzureOAuthProvider(config) + manager := NewEndpointManager(provider, config) + + // Test callback with POST method (should only accept GET) + req := httptest.NewRequest("POST", "/oauth/callback", nil) + w := httptest.NewRecorder() + + handler := manager.callbackHandler() + handler(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405 for POST method, got %d", w.Code) + } +} diff --git a/internal/auth/oauth/middleware.go b/internal/auth/oauth/middleware.go new file mode 100644 index 0000000..f951fef --- /dev/null +++ b/internal/auth/oauth/middleware.go @@ -0,0 +1,285 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + + "github.com/Azure/aks-mcp/internal/auth" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const tokenInfoKey contextKey = "token_info" + +// AuthMiddleware handles OAuth authentication for HTTP requests +type AuthMiddleware struct { + provider *AzureOAuthProvider + serverURL string +} + +// setCORSHeaders sets CORS headers for OAuth endpoints to allow MCP Inspector access +func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter) { + origin := "*" // TODO: Restrict to specific origins + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + w.Header().Set("Access-Control-Allow-Credentials", "false") // Explicit false for wildcard origin +} + +// NewAuthMiddleware creates a new authentication middleware +func NewAuthMiddleware(provider *AzureOAuthProvider, serverURL string) *AuthMiddleware { + return &AuthMiddleware{ + provider: provider, + serverURL: serverURL, + } +} + +// Middleware returns an HTTP middleware function for OAuth authentication +func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // Skip authentication for specific endpoints + if m.shouldSkipAuth(r) { + log.Printf("Skipping auth for path: %s\n", r.URL.Path) + next.ServeHTTP(w, r) + return + } + + // Perform authentication + authResult := m.authenticateRequest(r) + + if !authResult.Authenticated { + log.Printf("Authentication FAILED - handling error\n") + m.handleAuthError(w, r, authResult) + return + } + + // Add token info to request context + ctx := context.WithValue(r.Context(), tokenInfoKey, authResult.TokenInfo) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) +} + +// shouldSkipAuth determines if authentication should be skipped for this request +func (m *AuthMiddleware) shouldSkipAuth(r *http.Request) bool { + // Skip auth for OAuth metadata endpoints + path := r.URL.Path + + skipPaths := []string{ + "/.well-known/oauth-protected-resource", + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", + "/oauth2/v2.0/authorize", + "/oauth/register", + "/oauth/callback", + "/oauth2/v2.0/token", + "/oauth/introspect", + "/health", + "/ping", + } + + for _, skipPath := range skipPaths { + if path == skipPath { + return true + } + } + + return false +} + +// authenticateRequest performs OAuth authentication on the request +func (m *AuthMiddleware) authenticateRequest(r *http.Request) *auth.AuthResult { + // Extract Bearer token from Authorization header + authHeader := r.Header.Get("Authorization") + + if authHeader == "" { + log.Printf("OAuth DEBUG - Missing authorization header for %s %s\n", r.Method, r.URL.Path) + log.Printf("OAuth DEBUG - Request headers: %+v\n", r.Header) + return &auth.AuthResult{ + Authenticated: false, + Error: "missing authorization header", + StatusCode: http.StatusUnauthorized, + } + } + + // Check for Bearer token format + const bearerPrefix = "Bearer " + if !strings.HasPrefix(authHeader, bearerPrefix) { + log.Printf("FAILED - Invalid authorization header format (missing Bearer prefix)\n") + return &auth.AuthResult{ + Authenticated: false, + Error: "invalid authorization header format", + StatusCode: http.StatusUnauthorized, + } + } + + token := strings.TrimPrefix(authHeader, bearerPrefix) + if token == "" { + log.Printf("FAILED - Empty bearer token\n") + return &auth.AuthResult{ + Authenticated: false, + Error: "empty bearer token", + StatusCode: http.StatusUnauthorized, + } + } + + // Basic JWT structure validation + tokenParts := strings.Split(token, ".") + if len(tokenParts) != 3 { + log.Printf("FAILED - JWT structure validation (has %d parts, expected 3)\n", len(tokenParts)) + return &auth.AuthResult{ + Authenticated: false, + Error: "invalid JWT structure", + StatusCode: http.StatusUnauthorized, + } + } + + // Validate the token + tokenInfo, err := m.provider.ValidateToken(r.Context(), token) + if err != nil { + log.Printf("FAILED - Provider token validation failed: %v\n", err) + return &auth.AuthResult{ + Authenticated: false, + Error: fmt.Sprintf("token validation failed: %v", err), + StatusCode: http.StatusUnauthorized, + } + } + + // Validate required scopes - strict enforcement for security + if !m.validateScopes(tokenInfo.Scope) { + log.Printf("SCOPE ERROR: Token scopes %v don't match required scopes %v", tokenInfo.Scope, m.provider.config.RequiredScopes) + return &auth.AuthResult{ + Authenticated: false, + Error: "insufficient scope", + StatusCode: http.StatusForbidden, + } + } + + return &auth.AuthResult{ + Authenticated: true, + TokenInfo: tokenInfo, + StatusCode: http.StatusOK, + } +} + +// validateScopes checks if the token has required scopes +func (m *AuthMiddleware) validateScopes(tokenScopes []string) bool { + requiredScopes := m.provider.config.RequiredScopes + if len(requiredScopes) == 0 { + return true // No scopes required + } + + // Check if token has at least one required scope + for _, required := range requiredScopes { + if m.hasScopePermission(required, tokenScopes) { + return true + } + } + + return false +} + +// hasScopePermission checks if the token scopes satisfy the required scope +func (m *AuthMiddleware) hasScopePermission(requiredScope string, tokenScopes []string) bool { + // Direct scope match + for _, tokenScope := range tokenScopes { + if tokenScope == requiredScope { + return true + } + } + + // Azure resource scope mapping + azureResourceMappings := map[string][]string{ + "https://management.azure.com/.default": { + "user_impersonation", + "https://management.azure.com/user_impersonation", + "https://management.azure.com/.default", + "https://management.core.windows.net/", + "https://management.azure.com/", + }, + "https://graph.microsoft.com/.default": { + "User.Read", + "https://graph.microsoft.com/User.Read", + }, + } + + if allowedScopes, exists := azureResourceMappings[requiredScope]; exists { + for _, allowedScope := range allowedScopes { + for _, tokenScope := range tokenScopes { + if tokenScope == allowedScope { + return true + } + } + } + } + + return false +} + +// handleAuthError handles authentication errors +func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, authResult *auth.AuthResult) { + // Set CORS headers + m.setCORSHeaders(w) + w.Header().Set("Content-Type", "application/json") + + // Add WWW-Authenticate header for 401 responses (RFC 9728 Section 5.1) + if authResult.StatusCode == http.StatusUnauthorized { + // Build the resource metadata URL + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + host := r.Host + if host == "" { + host = r.URL.Host + } + serverURL := fmt.Sprintf("%s://%s", scheme, host) + resourceMetadataURL := fmt.Sprintf("%s/.well-known/oauth-protected-resource", serverURL) + + // RFC 9728 compliant WWW-Authenticate header + wwwAuth := fmt.Sprintf(`Bearer realm="%s", resource_metadata="%s"`, serverURL, resourceMetadataURL) + + // Add error information if available + if authResult.Error != "" { + wwwAuth += fmt.Sprintf(`, error="invalid_token", error_description="%s"`, authResult.Error) + } + + w.Header().Set("WWW-Authenticate", wwwAuth) + } + + w.WriteHeader(authResult.StatusCode) + + errorResponse := map[string]interface{}{ + "error": getOAuthErrorCode(authResult.StatusCode), + "error_description": authResult.Error, + } + + if err := json.NewEncoder(w).Encode(errorResponse); err != nil { + log.Printf("MIDDLEWARE ERROR: Failed to encode error response: %v\n", err) + } else { + log.Printf("MIDDLEWARE ERROR: Error response sent\n") + } +} + +// getOAuthErrorCode returns appropriate OAuth error code for HTTP status +func getOAuthErrorCode(statusCode int) string { + switch statusCode { + case http.StatusUnauthorized: + return "invalid_token" + case http.StatusForbidden: + return "insufficient_scope" + case http.StatusBadRequest: + return "invalid_request" + default: + return "server_error" + } +} diff --git a/internal/auth/oauth/middleware_test.go b/internal/auth/oauth/middleware_test.go new file mode 100644 index 0000000..ed73b4a --- /dev/null +++ b/internal/auth/oauth/middleware_test.go @@ -0,0 +1,251 @@ +package oauth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Azure/aks-mcp/internal/auth" +) + +// GetTokenInfo extracts token information from request context (test helper) +func GetTokenInfo(r *http.Request) (*auth.TokenInfo, bool) { + tokenInfo, ok := r.Context().Value(tokenInfoKey).(*auth.TokenInfo) + return tokenInfo, ok +} + +func TestAuthMiddleware(t *testing.T) { + // Create test config with minimal required scopes for testing + // Note: We cannot test with empty RequiredScopes because the OAuth configuration + // validation now requires at least one scope to be specified. This is intentional + // to prevent security misconfigurations in production environments. + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + middleware := NewAuthMiddleware(provider, "http://localhost:8000") + + // Create a test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + wrappedHandler := middleware.Middleware(testHandler) + + tests := []struct { + name string + authHeader string + expectedStatus int + path string + }{ + { + name: "valid bearer token", + authHeader: "Bearer header.payload.signature", + expectedStatus: http.StatusOK, + path: "/test", + }, + { + name: "missing authorization header", + authHeader: "", + expectedStatus: http.StatusUnauthorized, + path: "/test", + }, + { + name: "invalid token format", + authHeader: "InvalidFormat", + expectedStatus: http.StatusUnauthorized, + path: "/test", + }, + { + name: "non-bearer token", + authHeader: "Basic dXNlcjpwYXNz", + expectedStatus: http.StatusUnauthorized, + path: "/test", + }, + { + name: "skip auth for metadata endpoint", + authHeader: "", + expectedStatus: http.StatusOK, + path: "/.well-known/oauth-protected-resource", + }, + { + name: "skip auth for health endpoint", + authHeader: "", + expectedStatus: http.StatusOK, + path: "/health", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.path, nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + if w.Code != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) + } + + // Check WWW-Authenticate header for 401 responses + if w.Code == http.StatusUnauthorized { + wwwAuth := w.Header().Get("WWW-Authenticate") + if wwwAuth == "" { + t.Error("Expected WWW-Authenticate header for 401 response") + } + } + }) + } +} + +func TestAuthMiddlewareContextPropagation(t *testing.T) { + // Note: We cannot test with empty RequiredScopes because the OAuth configuration + // validation now requires at least one scope to be specified. + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + middleware := NewAuthMiddleware(provider, "http://localhost:8000") + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if token info is available in context + tokenInfo, ok := GetTokenInfo(r) + if !ok { + t.Error("Token info not found in context") + return + } + + if tokenInfo.AccessToken != "header.payload.signature" { + t.Errorf("Expected token header.payload.signature, got %s", tokenInfo.AccessToken) + } + + w.WriteHeader(http.StatusOK) + }) + + wrappedHandler := middleware.Middleware(testHandler) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer header.payload.signature") + + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } +} + +func TestShouldSkipAuth(t *testing.T) { + // Note: We cannot test with empty RequiredScopes because the OAuth configuration + // validation now requires at least one scope to be specified. + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, // Minimal scope for testing + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + middleware := NewAuthMiddleware(provider, "http://localhost:8000") + + tests := []struct { + path string + expected bool + }{ + {"/.well-known/oauth-protected-resource", true}, + {"/.well-known/oauth-authorization-server", true}, + {"/.well-known/openid-configuration", true}, + {"/oauth2/v2.0/authorize", true}, + {"/oauth/register", true}, + {"/oauth/callback", true}, + {"/oauth2/v2.0/token", true}, + {"/oauth/introspect", true}, + {"/health", true}, + {"/ping", true}, + {"/test", false}, + {"/mcp", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.path, nil) + result := middleware.shouldSkipAuth(req) + if result != tt.expected { + t.Errorf("Expected %v for path %s, got %v", tt.expected, tt.path, result) + } + }) + } +} + +func TestGetTokenInfo(t *testing.T) { + // Test with valid token info + tokenInfo := &auth.TokenInfo{ + AccessToken: "test-token", + TokenType: "Bearer", + Subject: "user123", + } + + ctx := context.WithValue(context.Background(), tokenInfoKey, tokenInfo) + req := httptest.NewRequest("GET", "/test", nil) + req = req.WithContext(ctx) + + retrievedTokenInfo, ok := GetTokenInfo(req) + if !ok { + t.Error("Expected to find token info in context") + } + + if retrievedTokenInfo.AccessToken != "test-token" { + t.Errorf("Expected access token test-token, got %s", retrievedTokenInfo.AccessToken) + } + + // Test without token info + req = httptest.NewRequest("GET", "/test", nil) + _, ok = GetTokenInfo(req) + if ok { + t.Error("Expected not to find token info in empty context") + } +} diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go new file mode 100644 index 0000000..cce007e --- /dev/null +++ b/internal/auth/oauth/provider.go @@ -0,0 +1,510 @@ +package oauth + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "math/big" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Azure/aks-mcp/internal/auth" + internalConfig "github.com/Azure/aks-mcp/internal/config" + "github.com/golang-jwt/jwt/v5" +) + +// AzureOAuthProvider implements OAuth authentication for Azure AD +type AzureOAuthProvider struct { + config *auth.OAuthConfig + httpClient *http.Client + keyCache *keyCache + enableCache bool +} + +// keyCache caches Azure AD signing keys +type keyCache struct { + keys map[string]*rsa.PublicKey + expiresAt time.Time + mu sync.RWMutex +} + +// AzureADMetadata represents Azure AD OAuth metadata +type AzureADMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + JWKSUri string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` +} + +// ProtectedResourceMetadata represents MCP protected resource metadata (RFC 9728 compliant) +type ProtectedResourceMetadata struct { + AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + ScopesSupported []string `json:"scopes_supported"` +} + +// ClientRegistrationRequest represents OAuth 2.0 Dynamic Client Registration request (RFC 7591) +type ClientRegistrationRequest struct { + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + ClientName string `json:"client_name"` + ClientURI string `json:"client_uri"` + Scope string `json:"scope"` +} + +// NewAzureOAuthProvider creates a new Azure OAuth provider +func NewAzureOAuthProvider(config *auth.OAuthConfig) (*AzureOAuthProvider, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid OAuth config: %w", err) + } + + return &AzureOAuthProvider{ + config: config, + enableCache: internalConfig.EnableCache, // Use config constant for cache control + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + keyCache: &keyCache{ + keys: make(map[string]*rsa.PublicKey), + }, + }, nil +} + +// GetProtectedResourceMetadata returns OAuth 2.0 Protected Resource Metadata (RFC 9728) +func (p *AzureOAuthProvider) GetProtectedResourceMetadata(serverURL string) (*ProtectedResourceMetadata, error) { + // For MCP compliance, point to our local authorization server proxy + // which properly advertises PKCE support + parsedURL, err := url.Parse(serverURL) + if err != nil { + return nil, fmt.Errorf("invalid server URL: %v", err) + } + + // Use the same scheme and host as the server URL + authServerURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) + + // RFC 9728 requires the resource field to identify this MCP server + return &ProtectedResourceMetadata{ + AuthorizationServers: []string{authServerURL}, + Resource: serverURL, // Required by MCP spec + ScopesSupported: p.config.RequiredScopes, + }, nil +} + +// GetAuthorizationServerMetadata returns OAuth 2.0 Authorization Server Metadata (RFC 8414) +func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (*AzureADMetadata, error) { + metadataURL := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", p.config.TenantID) + log.Printf("OAuth DEBUG: Fetching Azure AD metadata from: %s", metadataURL) + + resp, err := p.httpClient.Get(metadataURL) + if err != nil { + log.Printf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err) + return nil, fmt.Errorf("failed to fetch metadata from %s: %w", metadataURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + log.Printf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID) + return nil, fmt.Errorf("tenant ID '%s' not found (HTTP 404). Please verify your Azure AD tenant ID is correct", p.config.TenantID) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + log.Printf("OAuth ERROR: Metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) + return nil, fmt.Errorf("metadata endpoint returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("OAuth ERROR: Failed to read metadata response: %v", err) + return nil, fmt.Errorf("failed to read metadata response: %w", err) + } + + var metadata AzureADMetadata + if err := json.Unmarshal(body, &metadata); err != nil { + log.Printf("OAuth ERROR: Failed to parse metadata JSON: %v", err) + return nil, fmt.Errorf("failed to parse metadata: %w", err) + } + + log.Printf("OAuth DEBUG: Successfully parsed Azure AD metadata, original grant_types_supported: %v", metadata.GrantTypesSupported) + + // Ensure grant_types_supported is populated for MCP Inspector compatibility + if len(metadata.GrantTypesSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default grant_types_supported (was empty/nil)") + metadata.GrantTypesSupported = []string{"authorization_code", "refresh_token"} + } + + // Ensure response_types_supported is populated for MCP Inspector compatibility + if len(metadata.ResponseTypesSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default response_types_supported (was empty/nil)") + metadata.ResponseTypesSupported = []string{"code"} + } + + // Ensure subject_types_supported is populated for MCP Inspector compatibility + if len(metadata.SubjectTypesSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default subject_types_supported (was empty/nil)") + metadata.SubjectTypesSupported = []string{"public"} + } + + // Ensure token_endpoint_auth_methods_supported is populated for MCP Inspector compatibility + if len(metadata.TokenEndpointAuthMethodsSupported) == 0 { + log.Printf("OAuth DEBUG: Setting default token_endpoint_auth_methods_supported (was empty/nil)") + metadata.TokenEndpointAuthMethodsSupported = []string{"none"} + } + + // Add S256 code challenge method support (Azure AD supports this but may not advertise it) + // MCP specification requires S256 support, so we always ensure it's present + log.Printf("OAuth DEBUG: Enforcing S256 code challenge method support (MCP requirement)") + metadata.CodeChallengeMethodsSupported = []string{"S256"} + + // Azure AD v2.0 has limited support for RFC 8707 Resource Indicators + // - Authorization endpoint: doesn't support resource parameter + // - Token endpoint: doesn't support resource parameter + // - Uses scope-based resource identification instead + // Our proxy handles MCP resource parameter translation + parsedURL, err := url.Parse(serverURL) + if err == nil { + // If the server URL includes /mcp path, include it in the proxy endpoint + proxyPath := "/oauth2/v2.0/authorize" + tokenPath := "/oauth2/v2.0/token" + registrationPath := "/oauth/register" + proxyAuthURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, proxyPath) + tokenURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, tokenPath) + registrationURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, registrationPath) + + metadata.AuthorizationEndpoint = proxyAuthURL + metadata.TokenEndpoint = tokenURL + // Add dynamic client registration endpoint + metadata.RegistrationEndpoint = registrationURL + } + + log.Printf("OAuth DEBUG: Final metadata prepared - grant_types_supported: %v, response_types_supported: %v, code_challenge_methods_supported: %v", + metadata.GrantTypesSupported, metadata.ResponseTypesSupported, metadata.CodeChallengeMethodsSupported) + + return &metadata, nil +} + +// ValidateToken validates an OAuth access token +func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { + // Check if token looks like a valid JWT (should have exactly 2 dots) + dotCount := strings.Count(tokenString, ".") + if dotCount != 2 { + return nil, fmt.Errorf("invalid JWT token format: expected 3 parts separated by dots, got %d dots", dotCount) + } + + // If JWT validation is disabled, return a minimal token info without full validation + if !p.config.TokenValidation.ValidateJWT { + return &auth.TokenInfo{ + AccessToken: tokenString, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration + Scope: p.config.RequiredScopes, // Use configured scopes + Subject: "unknown", // Cannot extract without parsing + Audience: []string{p.config.TokenValidation.ExpectedAudience}, + Issuer: fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.config.TenantID), + Claims: make(map[string]interface{}), + }, nil + } + + // Parse and validate JWT token + + // Parse token structure and check expiration + parserUnsafe := jwt.NewParser(jwt.WithoutClaimsValidation()) + tokenUnsafe, _, err := parserUnsafe.ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("invalid token structure: %w", err) + } + + // Check claims and expiration + if claims, ok := tokenUnsafe.Claims.(jwt.MapClaims); ok { + if exp, ok := claims["exp"].(float64); ok { + expTime := time.Unix(int64(exp), 0) + if time.Now().After(expTime) { + return nil, fmt.Errorf("token expired at %v", expTime) + } + } + } + + // JWT signature validation + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, err := parser.ParseWithClaims(tokenString, jwt.MapClaims{}, p.getKeyFunc) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + if !token.Valid { + return nil, fmt.Errorf("invalid token") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid token claims") + } + + // Validate issuer - with Azure Management API scope, we should get v2.0 format + issuer, ok := claims["iss"].(string) + if !ok { + return nil, fmt.Errorf("missing issuer claim") + } + + expectedIssuerV2 := fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", p.config.TenantID) + expectedIssuerV1 := fmt.Sprintf("https://sts.windows.net/%s/", p.config.TenantID) + + if issuer != expectedIssuerV2 && issuer != expectedIssuerV1 { + return nil, fmt.Errorf("invalid issuer: expected %s (preferred) or %s (fallback), got %s", expectedIssuerV2, expectedIssuerV1, issuer) + } + + // Azure AD may return v1.0 or v2.0 issuer format depending on token scope + + // Validate audience and resource binding + if p.config.TokenValidation.ValidateAudience { + if err := p.validateAudience(claims); err != nil { + return nil, err + } + } + + // Extract token information + tokenInfo := &auth.TokenInfo{ + AccessToken: tokenString, + TokenType: "Bearer", + Claims: claims, + } + + // Extract subject + if sub, ok := claims["sub"].(string); ok { + tokenInfo.Subject = sub + } + + // Extract audience + if aud, ok := claims["aud"].(string); ok { + tokenInfo.Audience = []string{aud} + } else if audSlice, ok := claims["aud"].([]interface{}); ok { + for _, a := range audSlice { + if audStr, ok := a.(string); ok { + tokenInfo.Audience = append(tokenInfo.Audience, audStr) + } + } + } + + // Extract scope from Azure AD token + // Check for 'scp' claim (Azure AD v2.0) + if scp, ok := claims["scp"].(string); ok { + tokenInfo.Scope = strings.Split(scp, " ") + } else if scope, ok := claims["scope"].(string); ok { + // Check for 'scope' claim (alternative) + tokenInfo.Scope = strings.Split(scope, " ") + } + + // Check for 'roles' claim (Azure AD app roles) + if roles, ok := claims["roles"].([]interface{}); ok { + for _, role := range roles { + if roleStr, ok := role.(string); ok { + tokenInfo.Scope = append(tokenInfo.Scope, roleStr) + } + } + } + + // Extract expiration + if exp, ok := claims["exp"].(float64); ok { + tokenInfo.ExpiresAt = time.Unix(int64(exp), 0) + } + + // Set issuer + tokenInfo.Issuer = issuer + + return tokenInfo, nil +} + +// validateAudience validates the audience claim and resource binding (RFC 8707) +func (p *AzureOAuthProvider) validateAudience(claims jwt.MapClaims) error { + expectedAudience := p.config.TokenValidation.ExpectedAudience + + // Normalize expected audience - remove trailing slash for comparison + normalizedExpected := strings.TrimSuffix(expectedAudience, "/") + + // Check single audience + if aud, ok := claims["aud"].(string); ok { + normalizedAud := strings.TrimSuffix(aud, "/") + if normalizedAud == normalizedExpected || aud == p.config.ClientID { + return nil + } + return fmt.Errorf("invalid audience: expected %s or %s, got %s", expectedAudience, p.config.ClientID, aud) + } + + // Check audience array + if audSlice, ok := claims["aud"].([]interface{}); ok { + for _, a := range audSlice { + if audStr, ok := a.(string); ok { + normalizedAud := strings.TrimSuffix(audStr, "/") + if normalizedAud == normalizedExpected || audStr == p.config.ClientID { + return nil + } + } + } + return fmt.Errorf("invalid audience: expected %s or %s in audience list", expectedAudience, p.config.ClientID) + } + + return fmt.Errorf("missing audience claim") +} + +// getKeyFunc returns a function to retrieve JWT signing keys +func (p *AzureOAuthProvider) getKeyFunc(token *jwt.Token) (interface{}, error) { + // Validate signing method + if token.Method.Alg() != "RS256" { + return nil, fmt.Errorf("unexpected signing method: expected RS256, got %v", token.Method.Alg()) + } + + // Also verify it's an RSA method + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("signing method is not RSA: %T", token.Method) + } + + // Get key ID from token header + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("missing key ID in token header") + } + + // Extract issuer from token to determine the correct JWKS endpoint + var issuer string + if claims, ok := token.Claims.(jwt.MapClaims); ok { + if iss, ok := claims["iss"].(string); ok { + issuer = iss + } + } + + // Get the public key for this key ID using the appropriate issuer + key, err := p.getPublicKey(kid, issuer) + if err != nil { + log.Printf("PUBLIC KEY RETRIEVAL FAILED: %s\n", err) + return nil, fmt.Errorf("failed to get public key: %w", err) + } + + return key, nil +} + +// getPublicKey retrieves and caches Azure AD public keys +func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.PublicKey, error) { + // Generate cache key based on both kid and issuer to avoid conflicts between v1.0 and v2.0 keys + cacheKey := fmt.Sprintf("%s_%s", kid, issuer) + + // Check cache first if caching is enabled + if p.enableCache { + p.keyCache.mu.RLock() + if key, exists := p.keyCache.keys[cacheKey]; exists && time.Now().Before(p.keyCache.expiresAt) { + p.keyCache.mu.RUnlock() + return key, nil + } + p.keyCache.mu.RUnlock() + } + + // With Azure Management API scope, we should always get v2.0 format tokens + // Force using v2.0 JWKS endpoint for consistency + jwksURL := fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/v2.0/keys", p.config.TenantID) + + resp, err := p.httpClient.Get(jwksURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS from %s: %w", jwksURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read JWKS response: %w", err) + } + + var jwks struct { + Keys []struct { + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` + Kty string `json:"kty"` + } `json:"keys"` + } + + if err := json.Unmarshal(body, &jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS: %w", err) + } + + log.Printf("JWKS Contains %d keys, searching for kid=%s\n", len(jwks.Keys), kid) + + // Parse keys and find the target key + var targetKey *rsa.PublicKey + var foundKeyIds []string + + for _, key := range jwks.Keys { + foundKeyIds = append(foundKeyIds, key.Kid) + + if key.Kty == "RSA" && key.Kid == kid { + pubKey, err := parseRSAPublicKey(key.N, key.E) + if err != nil { + log.Printf("JWKS Failed to parse RSA key %s: %v\n", key.Kid, err) + continue + } + targetKey = pubKey + break + } + } + + // Cache the retrieved key and return it (only if caching is enabled) + if targetKey != nil { + if p.enableCache { + p.keyCache.mu.Lock() + if p.keyCache.keys == nil { + p.keyCache.keys = make(map[string]*rsa.PublicKey) + } + p.keyCache.keys[cacheKey] = targetKey + p.keyCache.expiresAt = time.Now().Add(24 * time.Hour) // Cache for 24 hours + p.keyCache.mu.Unlock() + } + return targetKey, nil + } + + return nil, fmt.Errorf("key with ID %s not found in JWKS (available: %v)", kid, foundKeyIds) +} + +// parseRSAPublicKey parses RSA public key from JWK format +func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { + // Decode base64url-encoded modulus + nBytes, err := base64.RawURLEncoding.DecodeString(nStr) + if err != nil { + return nil, fmt.Errorf("failed to decode modulus: %w", err) + } + + // Decode base64url-encoded exponent + eBytes, err := base64.RawURLEncoding.DecodeString(eStr) + if err != nil { + return nil, fmt.Errorf("failed to decode exponent: %w", err) + } + + // Convert bytes to big integers + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + + // Create RSA public key + pubKey := &rsa.PublicKey{ + N: n, + E: int(e.Int64()), + } + + return pubKey, nil +} diff --git a/internal/auth/oauth/provider_test.go b/internal/auth/oauth/provider_test.go new file mode 100644 index 0000000..66e9824 --- /dev/null +++ b/internal/auth/oauth/provider_test.go @@ -0,0 +1,381 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Azure/aks-mcp/internal/auth" +) + +func TestNewAzureOAuthProvider(t *testing.T) { + tests := []struct { + name string + config *auth.OAuthConfig + wantErr bool + }{ + { + name: "valid config should create provider", + config: &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + }, + wantErr: false, + }, + { + name: "invalid config should fail", + config: &auth.OAuthConfig{ + Enabled: true, + // Missing required fields + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := NewAzureOAuthProvider(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewAzureOAuthProvider() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && provider == nil { + t.Error("NewAzureOAuthProvider() returned nil provider") + } + }) + } +} + +func TestGetProtectedResourceMetadata(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + serverURL := "http://localhost:8000" + metadata, err := provider.GetProtectedResourceMetadata(serverURL) + if err != nil { + t.Fatalf("GetProtectedResourceMetadata() error = %v", err) + } + + expectedAuthServer := "http://localhost:8000" + if len(metadata.AuthorizationServers) != 1 || metadata.AuthorizationServers[0] != expectedAuthServer { + t.Errorf("Expected authorization server %s, got %v", expectedAuthServer, metadata.AuthorizationServers) + } + + // Note: AzureADProtectedResourceMetadata doesn't include a Resource field. + // The resource URL is implied by the context of the request endpoint. + + if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" { + t.Errorf("Expected scopes %v, got %v", config.RequiredScopes, metadata.ScopesSupported) + } +} + +func TestGetAuthorizationServerMetadataWithDefaults(t *testing.T) { + // Create a mock Azure AD metadata endpoint that's missing some fields + // This simulates the case where Azure AD doesn't provide all required fields + mockMetadata := AzureADMetadata{ + Issuer: "https://login.microsoftonline.com/test-tenant/v2.0", + AuthorizationEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize", + TokenEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token", + JWKSUri: "https://login.microsoftonline.com/test-tenant/discovery/v2.0/keys", + ScopesSupported: []string{"openid", "profile", "email"}, + // Intentionally omit GrantTypesSupported, ResponseTypesSupported, etc. + // to test our default value logic + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockMetadata) + })) + defer server.Close() + + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Override the HTTP client to use our test server + provider.httpClient = &http.Client{ + Transport: &roundTripperFunc{ + fn: func(req *http.Request) (*http.Response, error) { + // Redirect all requests to our test server + req.URL.Scheme = "http" + req.URL.Host = server.URL[7:] // Remove "http://" + req.URL.Path = "/" + return http.DefaultTransport.RoundTrip(req) + }, + }, + } + + metadata, err := provider.GetAuthorizationServerMetadata(server.URL) + if err != nil { + t.Fatalf("GetAuthorizationServerMetadata() error = %v", err) + } + + // Verify that default values were populated for missing fields + expectedGrantTypes := []string{"authorization_code", "refresh_token"} + if len(metadata.GrantTypesSupported) != len(expectedGrantTypes) { + t.Errorf("Expected %d grant types, got %d", len(expectedGrantTypes), len(metadata.GrantTypesSupported)) + } + for i, expected := range expectedGrantTypes { + if i >= len(metadata.GrantTypesSupported) || metadata.GrantTypesSupported[i] != expected { + t.Errorf("Expected grant type %s at index %d, got %v", expected, i, metadata.GrantTypesSupported) + } + } + + expectedResponseTypes := []string{"code"} + if len(metadata.ResponseTypesSupported) != len(expectedResponseTypes) { + t.Errorf("Expected %d response types, got %d", len(expectedResponseTypes), len(metadata.ResponseTypesSupported)) + } + if len(metadata.ResponseTypesSupported) > 0 && metadata.ResponseTypesSupported[0] != "code" { + t.Errorf("Expected response type 'code', got %s", metadata.ResponseTypesSupported[0]) + } + + expectedSubjectTypes := []string{"public"} + if len(metadata.SubjectTypesSupported) != len(expectedSubjectTypes) { + t.Errorf("Expected %d subject types, got %d", len(expectedSubjectTypes), len(metadata.SubjectTypesSupported)) + } + if len(metadata.SubjectTypesSupported) > 0 && metadata.SubjectTypesSupported[0] != "public" { + t.Errorf("Expected subject type 'public', got %s", metadata.SubjectTypesSupported[0]) + } + + expectedTokenEndpointAuthMethods := []string{"none"} + if len(metadata.TokenEndpointAuthMethodsSupported) != len(expectedTokenEndpointAuthMethods) { + t.Errorf("Expected %d auth methods, got %d", len(expectedTokenEndpointAuthMethods), len(metadata.TokenEndpointAuthMethodsSupported)) + } + if len(metadata.TokenEndpointAuthMethodsSupported) > 0 && metadata.TokenEndpointAuthMethodsSupported[0] != "none" { + t.Errorf("Expected auth method 'none', got %s", metadata.TokenEndpointAuthMethodsSupported[0]) + } + + // Verify that PKCE is properly configured + expectedCodeChallengeMethods := []string{"S256"} + if len(metadata.CodeChallengeMethodsSupported) != len(expectedCodeChallengeMethods) { + t.Errorf("Expected %d code challenge methods, got %d", len(expectedCodeChallengeMethods), len(metadata.CodeChallengeMethodsSupported)) + } + if len(metadata.CodeChallengeMethodsSupported) > 0 && metadata.CodeChallengeMethodsSupported[0] != "S256" { + t.Errorf("Expected code challenge method 'S256', got %s", metadata.CodeChallengeMethodsSupported[0]) + } +} + +func TestGetAuthorizationServerMetadata(t *testing.T) { + // Create a mock Azure AD metadata endpoint + mockMetadata := AzureADMetadata{ + Issuer: "https://login.microsoftonline.com/test-tenant/v2.0", + AuthorizationEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/authorize", + TokenEndpoint: "https://login.microsoftonline.com/test-tenant/oauth2/v2.0/token", + JWKSUri: "https://login.microsoftonline.com/test-tenant/discovery/v2.0/keys", + ScopesSupported: []string{"openid", "profile", "email"}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockMetadata) + })) + defer server.Close() + + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Override the HTTP client to use our test server + provider.httpClient = &http.Client{ + Transport: &roundTripperFunc{ + fn: func(req *http.Request) (*http.Response, error) { + // Redirect all requests to our test server + req.URL.Scheme = "http" + req.URL.Host = server.URL[7:] // Remove "http://" + req.URL.Path = "/" + return http.DefaultTransport.RoundTrip(req) + }, + }, + } + + metadata, err := provider.GetAuthorizationServerMetadata(server.URL) + if err != nil { + t.Fatalf("GetAuthorizationServerMetadata() error = %v", err) + } + + if metadata.Issuer != mockMetadata.Issuer { + t.Errorf("Expected issuer %s, got %s", mockMetadata.Issuer, metadata.Issuer) + } + + expectedAuthEndpoint := fmt.Sprintf("%s/oauth2/v2.0/authorize", server.URL) + if metadata.AuthorizationEndpoint != expectedAuthEndpoint { + t.Errorf("Expected auth endpoint %s, got %s", expectedAuthEndpoint, metadata.AuthorizationEndpoint) + } +} + +func TestValidateTokenWithoutJWT(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: false, // Disable JWT validation + ValidateAudience: false, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + ctx := context.Background() + // Use a token that looks like a JWT to pass initial format checks + testToken := "header.payload.signature" + tokenInfo, err := provider.ValidateToken(ctx, testToken) + if err != nil { + t.Fatalf("ValidateToken() error = %v", err) + } + + if tokenInfo.AccessToken != testToken { + t.Errorf("Expected access token %s, got %s", testToken, tokenInfo.AccessToken) + } + + if tokenInfo.TokenType != "Bearer" { + t.Errorf("Expected token type Bearer, got %s", tokenInfo.TokenType) + } +} + +func TestValidateAudience(t *testing.T) { + config := &auth.OAuthConfig{ + Enabled: true, + TenantID: "test-tenant", + ClientID: "test-client-id", + RequiredScopes: []string{"https://management.azure.com/.default"}, + TokenValidation: auth.TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: 5 * time.Minute, + ClockSkew: 1 * time.Minute, + }, + } + + provider, err := NewAzureOAuthProvider(config) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + tests := []struct { + name string + claims map[string]interface{} + wantErr bool + }{ + { + name: "valid audience string", + claims: map[string]interface{}{ + "aud": "https://management.azure.com/", + }, + wantErr: false, + }, + { + name: "valid client ID audience", + claims: map[string]interface{}{ + "aud": "test-client-id", + }, + wantErr: false, + }, + { + name: "valid audience array", + claims: map[string]interface{}{ + "aud": []interface{}{"https://management.azure.com/", "other-aud"}, + }, + wantErr: false, + }, + { + name: "invalid audience", + claims: map[string]interface{}{ + "aud": "invalid-audience", + }, + wantErr: true, + }, + { + name: "missing audience", + claims: map[string]interface{}{ + "sub": "user123", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := provider.validateAudience(tt.claims) + if (err != nil) != tt.wantErr { + t.Errorf("validateAudience() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// roundTripperFunc is a helper type for creating custom HTTP transports in tests +type roundTripperFunc struct { + fn func(*http.Request) (*http.Response, error) +} + +func (f *roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f.fn(req) +} diff --git a/internal/auth/types.go b/internal/auth/types.go new file mode 100644 index 0000000..f882a95 --- /dev/null +++ b/internal/auth/types.go @@ -0,0 +1,130 @@ +package auth + +import ( + "fmt" + "time" +) + +// OAuthConfig represents OAuth configuration for AKS-MCP +type OAuthConfig struct { + // Enable OAuth authentication + Enabled bool `json:"enabled"` + + // Azure AD tenant ID + TenantID string `json:"tenant_id"` + + // Azure AD application (client) ID + ClientID string `json:"client_id"` + + // Required OAuth scopes for accessing AKS-MCP + RequiredScopes []string `json:"required_scopes"` + + // Token validation settings + TokenValidation TokenValidationConfig `json:"token_validation"` +} + +// TokenValidationConfig represents token validation configuration +type TokenValidationConfig struct { + // Enable JWT token validation + ValidateJWT bool `json:"validate_jwt"` + + // Enable audience validation + ValidateAudience bool `json:"validate_audience"` + + // Expected audience for tokens + ExpectedAudience string `json:"expected_audience"` + + // Token cache TTL + CacheTTL time.Duration `json:"cache_ttl"` + + // Clock skew tolerance for token validation + ClockSkew time.Duration `json:"clock_skew"` +} + +// TokenInfo represents validated token information +type TokenInfo struct { + // Access token + AccessToken string `json:"access_token"` + + // Token type (usually "Bearer") + TokenType string `json:"token_type"` + + // Token expiration time + ExpiresAt time.Time `json:"expires_at"` + + // Token scope + Scope []string `json:"scope"` + + // Subject (user ID) + Subject string `json:"subject"` + + // Audience + Audience []string `json:"audience"` + + // Issuer + Issuer string `json:"issuer"` + + // Additional claims + Claims map[string]interface{} `json:"claims"` +} + +// AuthResult represents the result of authentication +type AuthResult struct { + // Whether authentication was successful + Authenticated bool `json:"authenticated"` + + // Token information (if authenticated) + TokenInfo *TokenInfo `json:"token_info,omitempty"` + + // Error message (if authentication failed) + Error string `json:"error,omitempty"` + + // HTTP status code to return + StatusCode int `json:"status_code"` +} + +// Default OAuth configuration values +const ( + DefaultTokenCacheTTL = 5 * time.Minute + DefaultClockSkew = 1 * time.Minute + DefaultExpectedAudience = "https://management.azure.com" + AzureADScope = "https://management.azure.com/.default" +) + +// NewDefaultOAuthConfig creates a default OAuth configuration +func NewDefaultOAuthConfig() *OAuthConfig { + return &OAuthConfig{ + Enabled: false, + // Use Azure Management API scope to get v2.0 format tokens + // This ensures we get v2.0 issuer format which works with v2.0 JWKS endpoints + RequiredScopes: []string{AzureADScope}, // "https://management.azure.com/.default" + TokenValidation: TokenValidationConfig{ + ValidateJWT: true, // Re-enabled now that we'll get v2.0 tokens + ValidateAudience: true, // Re-enabled with correct audience + ExpectedAudience: DefaultExpectedAudience, // "https://management.azure.com" + CacheTTL: DefaultTokenCacheTTL, + ClockSkew: DefaultClockSkew, + }, + } +} + +// Validate validates the OAuth configuration +func (cfg *OAuthConfig) Validate() error { + if !cfg.Enabled { + return nil + } + + if cfg.TenantID == "" { + return fmt.Errorf("tenant_id is required when OAuth is enabled") + } + + if cfg.ClientID == "" { + return fmt.Errorf("client_id is required when OAuth is enabled") + } + + // if len(cfg.RequiredScopes) == 0 { + // return fmt.Errorf("at least one required scope must be specified") + // } + + return nil +} diff --git a/internal/auth/types_test.go b/internal/auth/types_test.go new file mode 100644 index 0000000..fb84dcc --- /dev/null +++ b/internal/auth/types_test.go @@ -0,0 +1,175 @@ +package auth + +import ( + "os" + "testing" + "time" +) + +func TestOAuthConfigValidation(t *testing.T) { + tests := []struct { + name string + config *OAuthConfig + wantErr bool + }{ + { + name: "disabled OAuth should pass validation", + config: &OAuthConfig{ + Enabled: false, + }, + wantErr: false, + }, + { + name: "enabled OAuth with missing tenant ID should fail", + config: &OAuthConfig{ + Enabled: true, + ClientID: "test-client-id", + RequiredScopes: []string{"scope1"}, + }, + wantErr: true, + }, + { + name: "enabled OAuth with missing client ID should fail", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + RequiredScopes: []string{"scope1"}, + }, + wantErr: true, + }, + { + name: "enabled OAuth with empty scopes should pass", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{}, + }, + wantErr: false, + }, + { + name: "valid enabled OAuth config should pass", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{"scope1"}, + }, + wantErr: false, + }, + { + name: "valid enabled OAuth config with full token validation should pass", + config: &OAuthConfig{ + Enabled: true, + TenantID: "test-tenant-id", + ClientID: "test-client-id", + RequiredScopes: []string{"scope1"}, + TokenValidation: TokenValidationConfig{ + ValidateJWT: true, + ValidateAudience: true, + ExpectedAudience: "https://management.azure.com/", + CacheTTL: DefaultTokenCacheTTL, + ClockSkew: DefaultClockSkew, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("OAuthConfig.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNewDefaultOAuthConfig(t *testing.T) { + config := NewDefaultOAuthConfig() + + if config.Enabled { + t.Error("Default config should have OAuth disabled") + } + + if len(config.RequiredScopes) != 1 || config.RequiredScopes[0] != AzureADScope { + t.Errorf("Default config should have Azure AD scope, got %v", config.RequiredScopes) + } + + if !config.TokenValidation.ValidateJWT { + t.Error("Default config should enable JWT validation for production") + } + + if !config.TokenValidation.ValidateAudience { + t.Error("Default config should enable audience validation for security") + } + + if config.TokenValidation.ExpectedAudience != DefaultExpectedAudience { + t.Errorf("Default config should have correct expected audience, got %s", config.TokenValidation.ExpectedAudience) + } + + if config.TokenValidation.CacheTTL != DefaultTokenCacheTTL { + t.Errorf("Default config should have correct cache TTL, got %v", config.TokenValidation.CacheTTL) + } + + if config.TokenValidation.ClockSkew != DefaultClockSkew { + t.Errorf("Default config should have correct clock skew, got %v", config.TokenValidation.ClockSkew) + } +} + +func TestOAuthConfigConstants(t *testing.T) { + if DefaultTokenCacheTTL != 5*time.Minute { + t.Errorf("DefaultTokenCacheTTL should be 5 minutes, got %v", DefaultTokenCacheTTL) + } + + if DefaultClockSkew != 1*time.Minute { + t.Errorf("DefaultClockSkew should be 1 minute, got %v", DefaultClockSkew) + } + + if DefaultExpectedAudience != "https://management.azure.com" { + t.Errorf("DefaultExpectedAudience should be Azure management, got %s", DefaultExpectedAudience) + } + + if AzureADScope != "https://management.azure.com/.default" { + t.Errorf("AzureADScope should be Azure management default, got %s", AzureADScope) + } +} + +func TestOAuthConfigEnvironmentVariables(t *testing.T) { + // Test that environment variables are respected + oldTenantID := os.Getenv("AZURE_TENANT_ID") + oldClientID := os.Getenv("AZURE_CLIENT_ID") + + defer func() { + os.Setenv("AZURE_TENANT_ID", oldTenantID) + os.Setenv("AZURE_CLIENT_ID", oldClientID) + }() + + os.Setenv("AZURE_TENANT_ID", "env-tenant-id") + os.Setenv("AZURE_CLIENT_ID", "env-client-id") + + config := NewDefaultOAuthConfig() + config.Enabled = true + + // Simulate the environment variable loading that happens in config parsing + if config.TenantID == "" { + config.TenantID = os.Getenv("AZURE_TENANT_ID") + } + if config.ClientID == "" { + config.ClientID = os.Getenv("AZURE_CLIENT_ID") + } + + if config.TenantID != "env-tenant-id" { + t.Errorf("Expected tenant ID from environment, got %s", config.TenantID) + } + + if config.ClientID != "env-client-id" { + t.Errorf("Expected client ID from environment, got %s", config.ClientID) + } + + // Should pass validation with environment variables + if err := config.Validate(); err != nil { + t.Errorf("Config with environment variables should be valid, got error: %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index a653188..55c175a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,12 +8,18 @@ import ( "strings" "time" + "github.com/Azure/aks-mcp/internal/auth" "github.com/Azure/aks-mcp/internal/security" "github.com/Azure/aks-mcp/internal/telemetry" "github.com/Azure/aks-mcp/internal/version" flag "github.com/spf13/pflag" ) +// EnableCache controls whether caching is enabled globally +// Set to false during debugging to avoid cache-related issues +// This affects both web cache headers and AzureOAuthProvider cache +const EnableCache = false + // ConfigData holds the global configuration type ConfigData struct { // Command execution timeout in seconds @@ -22,6 +28,8 @@ type ConfigData struct { CacheTimeout time.Duration // Security configuration SecurityConfig *security.SecurityConfig + // OAuth configuration + OAuthConfig *auth.OAuthConfig // Command-line specific options Transport string @@ -51,6 +59,7 @@ func NewConfig() *ConfigData { Timeout: 60, CacheTimeout: 1 * time.Minute, SecurityConfig: security.NewSecurityConfig(), + OAuthConfig: auth.NewDefaultOAuthConfig(), Transport: "stdio", Port: 8000, AccessLevel: "readonly", @@ -66,9 +75,15 @@ func (cfg *ConfigData) ParseFlags() { flag.StringVar(&cfg.Host, "host", "127.0.0.1", "Host to listen for the server (only used with transport sse or streamable-http)") flag.IntVar(&cfg.Port, "port", 8000, "Port to listen for the server (only used with transport sse or streamable-http)") flag.IntVar(&cfg.Timeout, "timeout", 600, "Timeout for command execution in seconds, default is 600s") + // Security settings flag.StringVar(&cfg.AccessLevel, "access-level", "readonly", "Access level (readonly, readwrite, admin)") + // OAuth configuration + flag.BoolVar(&cfg.OAuthConfig.Enabled, "oauth-enabled", false, "Enable OAuth authentication") + flag.StringVar(&cfg.OAuthConfig.TenantID, "oauth-tenant-id", "", "Azure AD tenant ID for OAuth (fallback to AZURE_TENANT_ID env var)") + flag.StringVar(&cfg.OAuthConfig.ClientID, "oauth-client-id", "", "Azure AD client ID for OAuth (fallback to AZURE_CLIENT_ID env var)") + // Kubernetes-specific settings additionalTools := flag.String("additional-tools", "", "Comma-separated list of additional Kubernetes tools to support (kubectl is always enabled). Available: helm,cilium,hubble") @@ -113,6 +128,9 @@ func (cfg *ConfigData) ParseFlags() { cfg.SecurityConfig.AccessLevel = cfg.AccessLevel cfg.SecurityConfig.AllowedNamespaces = cfg.AllowNamespaces + // Parse OAuth configuration + cfg.parseOAuthConfig() + // Parse additional tools if *additionalTools != "" { tools := strings.Split(*additionalTools, ",") @@ -122,6 +140,30 @@ func (cfg *ConfigData) ParseFlags() { } } +// parseOAuthConfig parses OAuth-related command line arguments +func (cfg *ConfigData) parseOAuthConfig() { + // Note: OAuth scopes are automatically configured to use "https://management.azure.com/.default" + // and are not configurable via command line per design + + // Load OAuth configuration from environment variables if not set via CLI + if cfg.OAuthConfig.TenantID == "" { + cfg.OAuthConfig.TenantID = os.Getenv("AZURE_TENANT_ID") + } + if cfg.OAuthConfig.ClientID == "" { + cfg.OAuthConfig.ClientID = os.Getenv("AZURE_CLIENT_ID") + } +} + +// ValidateConfig validates the configuration for incompatible settings +func (cfg *ConfigData) ValidateConfig() error { + // Validate OAuth + transport compatibility + if cfg.OAuthConfig.Enabled && cfg.Transport == "stdio" { + return fmt.Errorf("OAuth authentication is not supported with stdio transport per MCP specification") + } + + return nil +} + // InitializeTelemetry initializes the telemetry service func (cfg *ConfigData) InitializeTelemetry(ctx context.Context, serviceName, serviceVersion string) { // Create telemetry configuration diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..a70c989 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,27 @@ +package config + +import ( + "testing" +) + +func TestBasicOAuthConfig(t *testing.T) { + // Test basic OAuth configuration parsing + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.OAuthConfig.TenantID = "test-tenant" + cfg.OAuthConfig.ClientID = "test-client" + + // Parse OAuth configuration + cfg.parseOAuthConfig() + + // Verify basic configuration is preserved + if !cfg.OAuthConfig.Enabled { + t.Error("Expected OAuth to be enabled") + } + if cfg.OAuthConfig.TenantID != "test-tenant" { + t.Errorf("Expected tenant ID 'test-tenant', got %s", cfg.OAuthConfig.TenantID) + } + if cfg.OAuthConfig.ClientID != "test-client" { + t.Errorf("Expected client ID 'test-client', got %s", cfg.OAuthConfig.ClientID) + } +} diff --git a/internal/config/validator.go b/internal/config/validator.go index 3499192..2f5214f 100644 --- a/internal/config/validator.go +++ b/internal/config/validator.go @@ -64,12 +64,22 @@ func (v *Validator) validateCli() bool { return valid } +// validateConfig checks configuration compatibility +func (v *Validator) validateConfig() bool { + if err := v.config.ValidateConfig(); err != nil { + v.errors = append(v.errors, err.Error()) + return false + } + return true +} + // Validate runs all validation checks func (v *Validator) Validate() bool { // Run all validation checks validCli := v.validateCli() + validConfig := v.validateConfig() - return validCli + return validCli && validConfig } // GetErrors returns all errors found during validation diff --git a/internal/server/server.go b/internal/server/server.go index 555ca37..fb3c485 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/Azure/aks-mcp/internal/auth/oauth" "github.com/Azure/aks-mcp/internal/azcli" "github.com/Azure/aks-mcp/internal/azureclient" "github.com/Azure/aks-mcp/internal/components/advisor" @@ -36,6 +37,9 @@ type Service struct { mcpServer *server.MCPServer azClient *azureclient.AzureClient azcliProcFactory func(timeout int) azcli.Proc + oauthProvider *oauth.AzureOAuthProvider + authMiddleware *oauth.AuthMiddleware + endpointManager *oauth.EndpointManager } // ServiceOption defines a function that configures the AKS MCP service @@ -83,6 +87,14 @@ func (s *Service) initializeInfrastructure() error { s.azClient = azClient log.Println("Azure client initialized successfully") + // Initialize OAuth components if enabled and transport is not stdio + // OAuth is not supported with stdio transport per MCP specification + if s.cfg.OAuthConfig.Enabled && s.cfg.Transport != "stdio" { + if err := s.initializeOAuth(); err != nil { + return fmt.Errorf("failed to initialize OAuth: %w", err) + } + } + // Ensure Azure CLI exists and is logged in if s.azcliProcFactory != nil { // Use injected factory to create an azcli.Proc @@ -114,6 +126,35 @@ func (s *Service) initializeInfrastructure() error { return nil } +// initializeOAuth initializes OAuth authentication components +func (s *Service) initializeOAuth() error { + log.Println("Initializing OAuth authentication...") + + // Validate OAuth configuration + if err := s.cfg.OAuthConfig.Validate(); err != nil { + return fmt.Errorf("invalid OAuth configuration: %w", err) + } + + // Create OAuth provider + provider, err := oauth.NewAzureOAuthProvider(s.cfg.OAuthConfig) + if err != nil { + return fmt.Errorf("failed to create OAuth provider: %w", err) + } + s.oauthProvider = provider + + // Create server URL for OAuth metadata + serverURL := fmt.Sprintf("http://%s:%d", s.cfg.Host, s.cfg.Port) + + // Create auth middleware + s.authMiddleware = oauth.NewAuthMiddleware(provider, serverURL) + + // Create endpoint manager + s.endpointManager = oauth.NewEndpointManager(provider, s.cfg.OAuthConfig) + + log.Printf("OAuth authentication initialized with tenant: %s", s.cfg.OAuthConfig.TenantID) + return nil +} + // registerAllComponents registers all component tools organized by category func (s *Service) registerAllComponents() { // Azure Components @@ -142,6 +183,15 @@ func (s *Service) registerPrompts() { func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { mux := http.NewServeMux() + // Register OAuth endpoints if OAuth is enabled + if s.cfg.OAuthConfig.Enabled { + if s.endpointManager == nil { + log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") + } + log.Println("Registering OAuth endpoints...") + s.endpointManager.RegisterEndpoints(mux) + } + // Handle all other paths with a helpful 404 response mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/mcp" { @@ -159,6 +209,20 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { }, } + // Add OAuth endpoints to the response if enabled + if s.cfg.OAuthConfig.Enabled { + oauthEndpoints := map[string]string{ + "oauth-metadata": "GET /.well-known/oauth-protected-resource - OAuth metadata", + "auth-server-metadata": "GET /.well-known/oauth-authorization-server - Authorization server metadata", + "client-registration": "POST /oauth/register - Dynamic client registration", + "token-introspection": "POST /oauth/introspect - Token introspection", + "health": "GET /health - Health check", + } + for k, v := range oauthEndpoints { + response["endpoints"].(map[string]string)[k] = v + } + } + if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } @@ -177,9 +241,28 @@ func (s *Service) createCustomHTTPServerWithHelp404(addr string) *http.Server { func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, addr string) *http.Server { mux := http.NewServeMux() - // Register SSE and Message handlers - mux.Handle("/sse", sseServer.SSEHandler()) - mux.Handle("/message", sseServer.MessageHandler()) + // Register OAuth endpoints if OAuth is enabled + if s.cfg.OAuthConfig.Enabled { + if s.endpointManager == nil { + log.Fatal("OAuth is enabled but endpoint manager is not initialized - this indicates a bug in server initialization") + } + log.Println("Registering OAuth endpoints for SSE server...") + s.endpointManager.RegisterEndpoints(mux) + } + + // Register SSE and Message handlers with authentication if enabled + if s.cfg.OAuthConfig.Enabled { + if s.authMiddleware == nil { + log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") + } + // Apply authentication middleware to SSE and Message endpoints + mux.Handle("/sse", s.authMiddleware.Middleware(sseServer.SSEHandler())) + mux.Handle("/message", s.authMiddleware.Middleware(sseServer.MessageHandler())) + } else { + // Register without authentication + mux.Handle("/sse", sseServer.SSEHandler()) + mux.Handle("/message", sseServer.MessageHandler()) + } // Handle all other paths with a helpful 404 response mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -196,6 +279,26 @@ func (s *Service) createCustomSSEServerWithHelp404(sseServer *server.SSEServer, }, } + // Add OAuth endpoints and authentication info if enabled + if s.cfg.OAuthConfig.Enabled { + response["authentication"] = map[string]interface{}{ + "required": true, + "type": "Bearer", + "note": "Include 'Authorization: Bearer ' header for authenticated endpoints", + } + + oauthEndpoints := map[string]string{ + "oauth-metadata": "GET /.well-known/oauth-protected-resource - OAuth metadata", + "auth-server-metadata": "GET /.well-known/oauth-authorization-server - Authorization server metadata", + "client-registration": "POST /oauth/register - Dynamic client registration", + "token-introspection": "POST /oauth/introspect - Token introspection", + "health": "GET /health - Health check", + } + for k, v := range oauthEndpoints { + response["endpoints"].(map[string]string)[k] = v + } + } + if err := json.NewEncoder(w).Encode(response); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } @@ -231,6 +334,10 @@ func (s *Service) Run() error { log.Printf("SSE endpoint available at: http://%s/sse", addr) log.Printf("Message endpoint available at: http://%s/message", addr) log.Printf("Connect to /sse for real-time events, send JSON-RPC to /message") + if s.cfg.OAuthConfig.Enabled { + log.Printf("OAuth authentication enabled - Bearer token required for SSE and Message endpoints") + log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) + } return customServer.ListenAndServe() case "streamable-http": @@ -247,12 +354,25 @@ func (s *Service) Run() error { // Update the mux to use the actual streamable server as the MCP handler if mux, ok := customServer.Handler.(*http.ServeMux); ok { - mux.Handle("/mcp", streamableServer) + if s.cfg.OAuthConfig.Enabled { + if s.authMiddleware == nil { + log.Fatal("OAuth is enabled but auth middleware is not initialized - this indicates a bug in server initialization") + } + // Apply authentication middleware to MCP endpoint + mux.Handle("/mcp", s.authMiddleware.Middleware(streamableServer)) + } else { + // Register without authentication + mux.Handle("/mcp", streamableServer) + } } log.Printf("Streamable HTTP server listening on %s", addr) log.Printf("MCP endpoint available at: http://%s/mcp", addr) log.Printf("Send POST requests to /mcp to initialize session and obtain Mcp-Session-Id") + if s.cfg.OAuthConfig.Enabled { + log.Printf("OAuth authentication enabled - Bearer token required for MCP endpoint") + log.Printf("OAuth metadata available at: http://%s/.well-known/oauth-protected-resource", addr) + } return customServer.ListenAndServe() default: From 4c550e8f441499810c74bab3d305ad30ddba1fbd Mon Sep 17 00:00:00 2001 From: Guoxun Wei Date: Mon, 8 Sep 2025 16:54:24 +0800 Subject: [PATCH 2/4] lint --- internal/auth/oauth/endpoints.go | 67 +++++++++++++++++++++++--- internal/auth/oauth/middleware_test.go | 4 +- internal/auth/oauth/provider.go | 14 ++++-- internal/auth/oauth/provider_test.go | 8 ++- internal/auth/types_test.go | 16 ++++-- 5 files changed, 91 insertions(+), 18 deletions(-) diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go index c4217b9..de31085 100644 --- a/internal/auth/oauth/endpoints.go +++ b/internal/auth/oauth/endpoints.go @@ -16,6 +16,31 @@ import ( "github.com/Azure/aks-mcp/internal/config" ) +// validateAzureADURL validates that the URL is a legitimate Azure AD endpoint +func validateAzureADURL(tokenURL string) error { + parsedURL, err := url.Parse(tokenURL) + if err != nil { + return fmt.Errorf("invalid URL format: %w", err) + } + + // Only allow HTTPS for security + if parsedURL.Scheme != "https" { + return fmt.Errorf("only HTTPS URLs are allowed") + } + + // Only allow Azure AD endpoints + if parsedURL.Host != "login.microsoftonline.com" { + return fmt.Errorf("only Azure AD endpoints are allowed") + } + + // Validate path format for token endpoint (should be /{tenantId}/oauth2/v2.0/token) + if !strings.Contains(parsedURL.Path, "/oauth2/v2.0/token") { + return fmt.Errorf("invalid Azure AD token endpoint path") + } + + return nil +} + // EndpointManager manages OAuth-related HTTP endpoints type EndpointManager struct { provider *AzureOAuthProvider @@ -296,7 +321,9 @@ func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode introspection response: %v", err) + } return } @@ -419,7 +446,9 @@ func (em *EndpointManager) writeErrorResponse(w http.ResponseWriter, errorCode, "error_description": description, } - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("Failed to encode error response: %v", err) + } } // authorizationProxyHandler proxies authorization requests to Azure AD with resource parameter filtering @@ -588,6 +617,11 @@ func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenRespo // Prepare token exchange request tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.config.TenantID) + // Validate URL for security + if err := validateAzureADURL(tokenURL); err != nil { + return nil, fmt.Errorf("invalid token URL: %w", err) + } + // Use default callback redirect URI for token exchange redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", 8080) // Default for MCP servers @@ -604,11 +638,15 @@ func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenRespo // For MCP compliance, we handle resource binding through audience validation // Make token exchange request - resp, err := http.PostForm(tokenURL, data) + resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above if err != nil { return nil, fmt.Errorf("token exchange request failed: %w", err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -649,7 +687,9 @@ func (em *EndpointManager) writeCallbackErrorResponse(w http.ResponseWriter, mes `, message) - w.Write([]byte(html)) + if _, err := w.Write([]byte(html)); err != nil { + log.Printf("Failed to write error response: %v", err) + } } // writeCallbackSuccessResponse writes a success response for callback @@ -733,7 +773,9 @@ func (em *EndpointManager) writeCallbackSuccessResponse(w http.ResponseWriter, t tokenResponse.AccessToken, tokenResponse.AccessToken) - w.Write([]byte(html)) + if _, err := w.Write([]byte(html)); err != nil { + log.Printf("Failed to write success response: %v", err) + } } // isValidClientID validates if a client ID is acceptable @@ -870,6 +912,11 @@ func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVer // Prepare token exchange request to Azure AD tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.config.TenantID) + // Validate URL for security + if err := validateAzureADURL(tokenURL); err != nil { + return nil, fmt.Errorf("invalid token URL: %w", err) + } + // Prepare form data data := url.Values{} data.Set("grant_type", "authorization_code") @@ -892,11 +939,15 @@ func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVer log.Printf("Azure AD token request with scope: %s", scope) // Make token exchange request to Azure AD - resp, err := http.PostForm(tokenURL, data) + resp, err := http.PostForm(tokenURL, data) // #nosec G107 -- URL is validated above if err != nil { return nil, fmt.Errorf("token exchange request failed: %w", err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) diff --git a/internal/auth/oauth/middleware_test.go b/internal/auth/oauth/middleware_test.go index ed73b4a..358cb1e 100644 --- a/internal/auth/oauth/middleware_test.go +++ b/internal/auth/oauth/middleware_test.go @@ -44,7 +44,9 @@ func TestAuthMiddleware(t *testing.T) { // Create a test handler testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("success")) + if _, err := w.Write([]byte("success")); err != nil { + t.Errorf("Failed to write test response: %v", err) + } }) wrappedHandler := middleware.Middleware(testHandler) diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go index cce007e..704316b 100644 --- a/internal/auth/oauth/provider.go +++ b/internal/auth/oauth/provider.go @@ -116,7 +116,11 @@ func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (* log.Printf("OAuth ERROR: Failed to fetch metadata from %s: %v", metadataURL, err) return nil, fmt.Errorf("failed to fetch metadata from %s: %w", metadataURL, err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() if resp.StatusCode == http.StatusNotFound { log.Printf("OAuth ERROR: Tenant ID '%s' not found (HTTP 404)", p.config.TenantID) @@ -181,7 +185,7 @@ func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (* if err == nil { // If the server URL includes /mcp path, include it in the proxy endpoint proxyPath := "/oauth2/v2.0/authorize" - tokenPath := "/oauth2/v2.0/token" + tokenPath := "/oauth2/v2.0/token" // #nosec G101 -- This is an OAuth endpoint path, not credentials registrationPath := "/oauth/register" proxyAuthURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, proxyPath) tokenURL := fmt.Sprintf("%s://%s%s", parsedURL.Scheme, parsedURL.Host, tokenPath) @@ -421,7 +425,11 @@ func (p *AzureOAuthProvider) getPublicKey(kid string, issuer string) (*rsa.Publi if err != nil { return nil, fmt.Errorf("failed to fetch JWKS from %s: %w", jwksURL, err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("Failed to close response body: %v", err) + } + }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode) diff --git a/internal/auth/oauth/provider_test.go b/internal/auth/oauth/provider_test.go index 66e9824..c8722d5 100644 --- a/internal/auth/oauth/provider_test.go +++ b/internal/auth/oauth/provider_test.go @@ -113,7 +113,9 @@ func TestGetAuthorizationServerMetadataWithDefaults(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockMetadata) + if err := json.NewEncoder(w).Encode(mockMetadata); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } })) defer server.Close() @@ -211,7 +213,9 @@ func TestGetAuthorizationServerMetadata(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockMetadata) + if err := json.NewEncoder(w).Encode(mockMetadata); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + } })) defer server.Close() diff --git a/internal/auth/types_test.go b/internal/auth/types_test.go index fb84dcc..d6becde 100644 --- a/internal/auth/types_test.go +++ b/internal/auth/types_test.go @@ -142,12 +142,20 @@ func TestOAuthConfigEnvironmentVariables(t *testing.T) { oldClientID := os.Getenv("AZURE_CLIENT_ID") defer func() { - os.Setenv("AZURE_TENANT_ID", oldTenantID) - os.Setenv("AZURE_CLIENT_ID", oldClientID) + if err := os.Setenv("AZURE_TENANT_ID", oldTenantID); err != nil { + t.Logf("Failed to restore AZURE_TENANT_ID: %v", err) + } + if err := os.Setenv("AZURE_CLIENT_ID", oldClientID); err != nil { + t.Logf("Failed to restore AZURE_CLIENT_ID: %v", err) + } }() - os.Setenv("AZURE_TENANT_ID", "env-tenant-id") - os.Setenv("AZURE_CLIENT_ID", "env-client-id") + if err := os.Setenv("AZURE_TENANT_ID", "env-tenant-id"); err != nil { + t.Fatalf("Failed to set AZURE_TENANT_ID: %v", err) + } + if err := os.Setenv("AZURE_CLIENT_ID", "env-client-id"); err != nil { + t.Fatalf("Failed to set AZURE_CLIENT_ID: %v", err) + } config := NewDefaultOAuthConfig() config.Enabled = true From e1dba67a1364976078f42bdd5321666d6949dc3f Mon Sep 17 00:00:00 2001 From: Guoxun Wei Date: Mon, 8 Sep 2025 17:56:54 +0800 Subject: [PATCH 3/4] avoid hard-coded port --- docs/oauth-authentication.md | 23 +---- internal/auth/oauth/endpoints.go | 42 ++++---- internal/auth/oauth/endpoints_test.go | 134 +++++++++----------------- internal/auth/oauth/provider.go | 6 +- internal/server/server.go | 2 +- 5 files changed, 75 insertions(+), 132 deletions(-) diff --git a/docs/oauth-authentication.md b/docs/oauth-authentication.md index 1588656..b62b15f 100644 --- a/docs/oauth-authentication.md +++ b/docs/oauth-authentication.md @@ -103,20 +103,7 @@ AKS-MCP uses the same environment variables (`AZURE_TENANT_ID`, `AZURE_CLIENT_ID ```bash # Create Azure AD application with public client platform az ad app create --display-name "AKS-MCP-OAuth" \ - --public-client-redirect-uris "http://localhost:8000/oauth/callback" "http://localhost:6274/oauth/callback/debug" - -# Get application details -az ad app list --display-name "AKS-MCP-OAuth" --query "[0].{appId:appId,objectId:objectId}" - -# Get your tenant ID -az account show --query "tenantId" -o tsv -``` - -**For Single-page application platform:** -```bash -# Create Azure AD application with SPA platform -az ad app create --display-name "AKS-MCP-OAuth" \ - --spa-redirect-uris "http://localhost:8000/oauth/callback" "http://localhost:6274/oauth/callback/debug" + --public-client-redirect-uris "http://localhost:8000/oauth/callback" # Get application details az ad app list --display-name "AKS-MCP-OAuth" --query "[0].{appId:appId,objectId:objectId}" @@ -196,8 +183,6 @@ curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp ./aks-mcp --oauth-enabled --access-level=readonly ``` -**Note for MCP Inspector**: MCP Inspector requires the callback URL `http://localhost:6274/oauth/callback/debug` to be configured in your Azure AD application redirect URIs for proper OAuth testing. - ### Step 4: Start AKS-MCP with OAuth ```bash @@ -208,7 +193,7 @@ curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp --oauth-enabled \ --oauth-tenant-id="$AZURE_TENANT_ID" \ --oauth-client-id="$AZURE_CLIENT_ID" \ - --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback/debug" \ + --oauth-redirects="http://localhost:8000/oauth/callback" \ --access-level=readonly # Using SSE transport with OAuth (alternative) @@ -218,7 +203,7 @@ curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp --oauth-enabled \ --oauth-tenant-id="$AZURE_TENANT_ID" \ --oauth-client-id="$AZURE_CLIENT_ID" \ - --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback/debug" \ + --oauth-redirects="http://localhost:8000/oauth/callback" \ --access-level=readonly # Environment variables are automatically used if set @@ -243,7 +228,7 @@ curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp ./aks-mcp --transport=sse --oauth-enabled=true \ --oauth-tenant-id="12345678-1234-1234-1234-123456789012" \ --oauth-client-id="87654321-4321-4321-4321-210987654321" \ - --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback/debug" + --oauth-redirects="http://localhost:8000/oauth/callback" ``` **Note**: Scopes are automatically set to `https://management.azure.com/.default` and cannot be customized via command line. diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go index de31085..f7d9c5d 100644 --- a/internal/auth/oauth/endpoints.go +++ b/internal/auth/oauth/endpoints.go @@ -44,14 +44,14 @@ func validateAzureADURL(tokenURL string) error { // EndpointManager manages OAuth-related HTTP endpoints type EndpointManager struct { provider *AzureOAuthProvider - config *auth.OAuthConfig + cfg *config.ConfigData } // NewEndpointManager creates a new OAuth endpoint manager -func NewEndpointManager(provider *AzureOAuthProvider, config *auth.OAuthConfig) *EndpointManager { +func NewEndpointManager(provider *AzureOAuthProvider, cfg *config.ConfigData) *EndpointManager { return &EndpointManager{ provider: provider, - config: config, + cfg: cfg, } } @@ -218,7 +218,7 @@ func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { // For Azure AD compatibility, use the configured client ID // In a full RFC 7591 implementation, each registration would get a unique ID // But since Azure AD requires pre-registered client IDs, we return the configured one - clientID := em.config.ClientID + clientID := em.cfg.OAuthConfig.ClientID log.Printf("OAuth DEBUG: Client registration successful - returning client_id: %s", clientID) @@ -330,7 +330,7 @@ func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { // Return active token response response := map[string]interface{}{ "active": true, - "client_id": em.config.ClientID, + "client_id": em.cfg.OAuthConfig.ClientID, "scope": strings.Join(tokenInfo.Scope, " "), "sub": tokenInfo.Subject, "aud": tokenInfo.Audience, @@ -366,7 +366,7 @@ func (em *EndpointManager) healthHandler() http.HandlerFunc { response := map[string]interface{}{ "status": "healthy", "oauth": map[string]interface{}{ - "enabled": em.config.Enabled, + "enabled": em.cfg.OAuthConfig.Enabled, }, } @@ -508,14 +508,14 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { // Use only server-required scopes for Azure AD compatibility // Azure AD .default scopes cannot be mixed with OpenID Connect scopes // We prioritize Azure Management API access over OpenID Connect user info - finalScopes := em.config.RequiredScopes + finalScopes := em.cfg.OAuthConfig.RequiredScopes finalScopeString := strings.Join(finalScopes, " ") query.Set("scope", finalScopeString) log.Printf("OAuth DEBUG: Setting final scope for Azure AD: %s", finalScopeString) // Build the Azure AD authorization URL - azureAuthURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", em.config.TenantID) + azureAuthURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", em.cfg.OAuthConfig.TenantID) // Create the redirect URL with filtered parameters redirectURL := fmt.Sprintf("%s?%s", azureAuthURL, query.Encode()) @@ -590,11 +590,11 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { tokenInfo := &auth.TokenInfo{ AccessToken: tokenResponse.AccessToken, TokenType: "Bearer", - ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration - Scope: em.config.RequiredScopes, // Use configured scopes - Subject: "authenticated_user", // Placeholder - Audience: []string{fmt.Sprintf("https://sts.windows.net/%s/", em.config.TenantID)}, - Issuer: fmt.Sprintf("https://sts.windows.net/%s/", em.config.TenantID), + ExpiresAt: time.Now().Add(time.Hour), // Default 1 hour expiration + Scope: em.cfg.OAuthConfig.RequiredScopes, // Use configured scopes + Subject: "authenticated_user", // Placeholder + Audience: []string{fmt.Sprintf("https://sts.windows.net/%s/", em.cfg.OAuthConfig.TenantID)}, + Issuer: fmt.Sprintf("https://sts.windows.net/%s/", em.cfg.OAuthConfig.TenantID), Claims: make(map[string]interface{}), } @@ -615,7 +615,7 @@ type TokenResponse struct { // exchangeCodeForToken exchanges authorization code for access token func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenResponse, error) { // Prepare token exchange request - tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.config.TenantID) + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.cfg.OAuthConfig.TenantID) // Validate URL for security if err := validateAzureADURL(tokenURL); err != nil { @@ -623,15 +623,15 @@ func (em *EndpointManager) exchangeCodeForToken(code, state string) (*TokenRespo } // Use default callback redirect URI for token exchange - redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", 8080) // Default for MCP servers + redirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port) // Prepare form data data := url.Values{} data.Set("grant_type", "authorization_code") - data.Set("client_id", em.config.ClientID) + data.Set("client_id", em.cfg.OAuthConfig.ClientID) data.Set("code", code) data.Set("redirect_uri", redirectURI) - data.Set("scope", strings.Join(em.config.RequiredScopes, " ")) + data.Set("scope", strings.Join(em.cfg.OAuthConfig.RequiredScopes, " ")) // Note: Azure AD v2.0 doesn't support the 'resource' parameter in token requests // It uses scope-based resource identification instead @@ -781,7 +781,7 @@ func (em *EndpointManager) writeCallbackSuccessResponse(w http.ResponseWriter, t // isValidClientID validates if a client ID is acceptable func (em *EndpointManager) isValidClientID(clientID string) bool { // Accept configured client ID (primary method for Azure AD) - if clientID == em.config.ClientID { + if clientID == em.cfg.OAuthConfig.ClientID { return true } @@ -879,7 +879,7 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { requestedScope := r.FormValue("scope") if requestedScope == "" { // Fallback to server required scopes if not provided - requestedScope = strings.Join(em.config.RequiredScopes, " ") + requestedScope = strings.Join(em.cfg.OAuthConfig.RequiredScopes, " ") } log.Printf("OAuth DEBUG: Exchanging authorization code for access token with Azure AD, scope: %s", requestedScope) @@ -910,7 +910,7 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { // exchangeCodeForTokenDirect exchanges authorization code for access token directly with Azure AD func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVerifier, scope string) (*TokenResponse, error) { // Prepare token exchange request to Azure AD - tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.config.TenantID) + tokenURL := fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", em.cfg.OAuthConfig.TenantID) // Validate URL for security if err := validateAzureADURL(tokenURL); err != nil { @@ -920,7 +920,7 @@ func (em *EndpointManager) exchangeCodeForTokenDirect(code, redirectURI, codeVer // Prepare form data data := url.Values{} data.Set("grant_type", "authorization_code") - data.Set("client_id", em.config.ClientID) + data.Set("client_id", em.cfg.OAuthConfig.ClientID) data.Set("code", code) data.Set("redirect_uri", redirectURI) data.Set("scope", scope) // Use the scope provided by the client diff --git a/internal/auth/oauth/endpoints_test.go b/internal/auth/oauth/endpoints_test.go index c2aa1f8..f044220 100644 --- a/internal/auth/oauth/endpoints_test.go +++ b/internal/auth/oauth/endpoints_test.go @@ -8,10 +8,15 @@ import ( "testing" "github.com/Azure/aks-mcp/internal/auth" + "github.com/Azure/aks-mcp/internal/config" ) -func TestEndpointManager_RegisterEndpoints(t *testing.T) { - config := &auth.OAuthConfig{ +// createTestConfig creates a test ConfigData with OAuth configuration +func createTestConfig() *config.ConfigData { + cfg := config.NewConfig() + cfg.Host = "127.0.0.1" + cfg.Port = 8000 + cfg.OAuthConfig = &auth.OAuthConfig{ Enabled: true, TenantID: "test-tenant", ClientID: "test-client", @@ -22,9 +27,14 @@ func TestEndpointManager_RegisterEndpoints(t *testing.T) { ExpectedAudience: "https://management.azure.com/", }, } + return cfg +} + +func TestEndpointManager_RegisterEndpoints(t *testing.T) { + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) mux := http.NewServeMux() manager.RegisterEndpoints(mux) @@ -58,15 +68,10 @@ func TestEndpointManager_RegisterEndpoints(t *testing.T) { } func TestProtectedResourceMetadataEndpoint(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) w := httptest.NewRecorder() @@ -89,20 +94,15 @@ func TestProtectedResourceMetadataEndpoint(t *testing.T) { } if len(metadata.ScopesSupported) != 1 || metadata.ScopesSupported[0] != "https://management.azure.com/.default" { - t.Errorf("Expected scopes %v, got %v", config.RequiredScopes, metadata.ScopesSupported) + t.Errorf("Expected scopes %v, got %v", cfg.OAuthConfig.RequiredScopes, metadata.ScopesSupported) } } func TestClientRegistrationEndpoint(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test valid registration request registrationRequest := map[string]interface{}{ @@ -142,19 +142,10 @@ func TestClientRegistrationEndpoint(t *testing.T) { } func TestTokenIntrospectionEndpoint(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - TokenValidation: auth.TokenValidationConfig{ - ValidateJWT: false, - ValidateAudience: false, - }, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test with valid token (since JWT validation is disabled, any token works) // Note: Must use a token that looks like a JWT (has dots) to pass initial format checks @@ -180,15 +171,10 @@ func TestTokenIntrospectionEndpoint(t *testing.T) { } func TestTokenIntrospectionEndpointMissingToken(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test without token parameter req := httptest.NewRequest("POST", "/oauth/introspect", strings.NewReader("")) @@ -204,15 +190,10 @@ func TestTokenIntrospectionEndpointMissingToken(t *testing.T) { } func TestHealthEndpoint(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) req := httptest.NewRequest("GET", "/health", nil) w := httptest.NewRecorder() @@ -244,15 +225,10 @@ func TestHealthEndpoint(t *testing.T) { } func TestValidateClientRegistration(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) tests := []struct { name string @@ -320,15 +296,10 @@ func TestValidateClientRegistration(t *testing.T) { } func TestCallbackEndpointMissingCode(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test callback without authorization code req := httptest.NewRequest("GET", "/oauth/callback?state=test-state", nil) @@ -354,15 +325,10 @@ func TestCallbackEndpointMissingCode(t *testing.T) { } func TestCallbackEndpointMissingState(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test callback without state parameter req := httptest.NewRequest("GET", "/oauth/callback?code=test-code", nil) @@ -382,15 +348,10 @@ func TestCallbackEndpointMissingState(t *testing.T) { } func TestCallbackEndpointAuthError(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test callback with authorization error req := httptest.NewRequest("GET", "/oauth/callback?error=access_denied&error_description=User%20denied%20access", nil) @@ -413,15 +374,10 @@ func TestCallbackEndpointAuthError(t *testing.T) { } func TestCallbackEndpointMethodNotAllowed(t *testing.T) { - config := &auth.OAuthConfig{ - Enabled: true, - TenantID: "test-tenant", - ClientID: "test-client", - RequiredScopes: []string{"https://management.azure.com/.default"}, - } + cfg := createTestConfig() - provider, _ := NewAzureOAuthProvider(config) - manager := NewEndpointManager(provider, config) + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) // Test callback with POST method (should only accept GET) req := httptest.NewRequest("POST", "/oauth/callback", nil) diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go index 704316b..0ef4848 100644 --- a/internal/auth/oauth/provider.go +++ b/internal/auth/oauth/provider.go @@ -205,9 +205,11 @@ func (p *AzureOAuthProvider) GetAuthorizationServerMetadata(serverURL string) (* // ValidateToken validates an OAuth access token func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { - // Check if token looks like a valid JWT (should have exactly 2 dots) + // JWTs have three parts (header.payload.signature) separated by two dots. + const jwtExpectedDotCount = 2 + dotCount := strings.Count(tokenString, ".") - if dotCount != 2 { + if dotCount != jwtExpectedDotCount { return nil, fmt.Errorf("invalid JWT token format: expected 3 parts separated by dots, got %d dots", dotCount) } diff --git a/internal/server/server.go b/internal/server/server.go index fb3c485..39b4513 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -149,7 +149,7 @@ func (s *Service) initializeOAuth() error { s.authMiddleware = oauth.NewAuthMiddleware(provider, serverURL) // Create endpoint manager - s.endpointManager = oauth.NewEndpointManager(provider, s.cfg.OAuthConfig) + s.endpointManager = oauth.NewEndpointManager(provider, s.cfg) log.Printf("OAuth authentication initialized with tenant: %s", s.cfg.OAuthConfig.TenantID) return nil From a4a241b18867fb4df5cab6a9830553c0e5b77df9 Mon Sep 17 00:00:00 2001 From: Guoxun Wei Date: Tue, 9 Sep 2025 10:52:45 +0800 Subject: [PATCH 4/4] address comments --- docs/oauth-authentication.md | 39 ++++- internal/auth/oauth/endpoints.go | 105 +++++++++--- internal/auth/oauth/endpoints_test.go | 209 +++++++++++++++++++++++ internal/auth/oauth/middleware.go | 34 ++-- internal/auth/oauth/provider.go | 5 +- internal/auth/oauth/provider_test.go | 3 + internal/auth/types.go | 14 +- internal/config/config.go | 105 +++++++++++- internal/config/config_test.go | 228 +++++++++++++++++++++++++- 9 files changed, 688 insertions(+), 54 deletions(-) diff --git a/docs/oauth-authentication.md b/docs/oauth-authentication.md index b62b15f..344279f 100644 --- a/docs/oauth-authentication.md +++ b/docs/oauth-authentication.md @@ -219,6 +219,7 @@ curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp - `--oauth-tenant-id`: Azure AD tenant ID (or use AZURE_TENANT_ID env var) - `--oauth-client-id`: Azure AD client ID (or use AZURE_CLIENT_ID env var) - `--oauth-redirects`: Comma-separated list of allowed redirect URIs (required when OAuth enabled) +- `--oauth-cors-origins`: Comma-separated list of allowed CORS origins for OAuth endpoints (e.g. http://localhost:6274 for MCP Inspector). If empty, no cross-origin requests are allowed for security **Note**: OAuth scopes are automatically configured to use `https://management.azure.com/.default` for optimal Azure AD compatibility. Custom scopes are not currently configurable via command line. @@ -227,8 +228,7 @@ curl -H "Authorization: Bearer YOUR_TOKEN" http://localhost:8000/mcp ```bash ./aks-mcp --transport=sse --oauth-enabled=true \ --oauth-tenant-id="12345678-1234-1234-1234-123456789012" \ - --oauth-client-id="87654321-4321-4321-4321-210987654321" \ - --oauth-redirects="http://localhost:8000/oauth/callback" + --oauth-client-id="87654321-4321-4321-4321-210987654321" ``` **Note**: Scopes are automatically set to `https://management.azure.com/.default` and cannot be customized via command line. @@ -436,4 +436,39 @@ To migrate from a non-OAuth AKS-MCP deployment: The MCP Inspector tool can be used to test OAuth-enabled AKS-MCP servers. Configure the Inspector's OAuth settings to match your AKS-MCP OAuth configuration for testing. +### Important: Redirect URI Configuration for MCP Inspector + +When using MCP Inspector with OAuth authentication, you need to add the Inspector's proxy redirect URI to your OAuth configuration: + +```bash +# Add Inspector's redirect URI (typically http://localhost:6274/oauth/callback) +./aks-mcp \ + --transport=streamable-http \ + --port=8000 \ + --oauth-enabled \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback" \ + --access-level=readonly +``` + +**Key Points:** +- MCP Inspector typically runs on port 6274 by default +- The Inspector creates a proxy redirect URI at `/oauth/callback` +- You must include both your server's redirect URI AND the Inspector's redirect URI +- You must also configure CORS origins to allow the Inspector's web interface to make requests +- Comma-separate multiple redirect URIs in the `--oauth-redirects` parameter +- Comma-separate multiple CORS origins in the `--oauth-cors-origins` parameter +- Without the Inspector's redirect URI, OAuth authentication will fail with "redirect_uri not registered" error +- Without the Inspector's CORS origin, the web interface will be blocked by browser CORS policy + +**Example with MCP Inspector configuration:** +```bash +./aks-mcp \ + --transport=streamable-http \ + --port=8000 \ + --oauth-enabled \ + --oauth-redirects="http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback" \ + --oauth-cors-origins="http://localhost:6274" \ + --access-level=readonly +``` + For more information, see the MCP OAuth specification and Azure AD documentation. \ No newline at end of file diff --git a/internal/auth/oauth/endpoints.go b/internal/auth/oauth/endpoints.go index f7d9c5d..af49899 100644 --- a/internal/auth/oauth/endpoints.go +++ b/internal/auth/oauth/endpoints.go @@ -55,15 +55,29 @@ func NewEndpointManager(provider *AzureOAuthProvider, cfg *config.ConfigData) *E } } -// setCORSHeaders sets CORS headers for OAuth endpoints to allow MCP Inspector access -func (em *EndpointManager) setCORSHeaders(w http.ResponseWriter) { - origin := "*" // TODO: Restrict to specific origins - - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") - w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours - w.Header().Set("Access-Control-Allow-Credentials", "false") // Explicit false for wildcard origin +// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting +func (em *EndpointManager) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + requestOrigin := r.Header.Get("Origin") + + // Check if the request origin is in the allowed list + var allowedOrigin string + for _, allowed := range em.provider.config.AllowedOrigins { + if requestOrigin == allowed { + allowedOrigin = requestOrigin + break + } + } + + // Only set CORS headers if origin is allowed + if allowedOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + w.Header().Set("Access-Control-Allow-Credentials", "false") + } else if requestOrigin != "" { + log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) + } } // setCacheHeaders sets cache control headers based on EnableCache configuration @@ -116,7 +130,7 @@ func (em *EndpointManager) authServerMetadataProxyHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Received request for authorization server metadata: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -171,7 +185,7 @@ func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Received client registration request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -220,8 +234,6 @@ func (em *EndpointManager) clientRegistrationHandler() http.HandlerFunc { // But since Azure AD requires pre-registered client IDs, we return the configured one clientID := em.cfg.OAuthConfig.ClientID - log.Printf("OAuth DEBUG: Client registration successful - returning client_id: %s", clientID) - clientInfo := map[string]interface{}{ "client_id": clientID, // Use configured Azure AD client ID "client_id_issued_at": time.Now().Unix(), // RFC 7591: timestamp of issuance @@ -284,11 +296,28 @@ func (em *EndpointManager) validateClientRegistration(req *ClientRegistrationReq return nil } +// validateRedirectURI validates that a redirect URI is registered and allowed +func (em *EndpointManager) validateRedirectURI(redirectURI string) error { + if len(em.cfg.OAuthConfig.RedirectURIs) == 0 { + return fmt.Errorf("no redirect URIs configured") + } + + for _, allowed := range em.cfg.OAuthConfig.RedirectURIs { + if redirectURI == allowed { + return nil + } + } + + log.Printf("OAuth SECURITY WARNING: Invalid redirect URI attempted: %s, allowed: %v", + redirectURI, em.cfg.OAuthConfig.RedirectURIs) + return fmt.Errorf("redirect_uri not registered: %s", redirectURI) +} + // tokenIntrospectionHandler implements RFC 7662 OAuth 2.0 Token Introspection func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -350,7 +379,7 @@ func (em *EndpointManager) tokenIntrospectionHandler() http.HandlerFunc { func (em *EndpointManager) healthHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -384,7 +413,7 @@ func (em *EndpointManager) protectedResourceMetadataHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Received request for protected resource metadata: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -457,7 +486,7 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Received authorization proxy request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -474,6 +503,23 @@ func (em *EndpointManager) authorizationProxyHandler() http.HandlerFunc { // Parse query parameters query := r.URL.Query() + // Validate redirect_uri parameter for security and better user experience + redirectURI := query.Get("redirect_uri") + if redirectURI == "" { + log.Printf("OAuth ERROR: Missing redirect_uri parameter in authorization request") + log.Printf("OAuth HELP: To fix this error, configure redirect URIs using --oauth-redirects flag") + log.Printf("OAuth HELP: For MCP Inspector, use: --oauth-redirects=\"http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback\"") + em.writeErrorResponse(w, "invalid_request", "redirect_uri parameter is required", http.StatusBadRequest) + return + } + + // Validate that the redirect_uri is registered and allowed + if err := em.validateRedirectURI(redirectURI); err != nil { + log.Printf("OAuth ERROR: redirect_uri %s not registered - requests will be blocked for security", redirectURI) + em.writeErrorResponse(w, "invalid_request", fmt.Sprintf("redirect_uri not registered: %s", redirectURI), http.StatusBadRequest) + return + } + // Enforce PKCE for OAuth 2.1 compliance (MCP requirement) codeChallenge := query.Get("code_challenge") codeChallengeMethod := query.Get("code_challenge_method") @@ -532,7 +578,7 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Received callback request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -575,6 +621,14 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Callback parameters validated - has_code: true, state: %s", state) + // Validate redirect URI for security - construct expected URI and validate it + expectedRedirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", em.cfg.Host, em.cfg.Port) + if err := em.validateRedirectURI(expectedRedirectURI); err != nil { + log.Printf("OAuth ERROR: Redirect URI validation failed: %v", err) + em.writeCallbackErrorResponse(w, "Invalid redirect URI") + return + } + // Exchange authorization code for access token tokenResponse, err := em.exchangeCodeForToken(code, state) if err != nil { @@ -583,8 +637,6 @@ func (em *EndpointManager) callbackHandler() http.HandlerFunc { return } - log.Printf("OAuth DEBUG: Token exchange successful, preparing callback success response") - // Skip token validation in callback - validation happens during MCP requests // Create minimal token info for callback success page tokenInfo := &auth.TokenInfo{ @@ -806,7 +858,7 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { log.Printf("OAuth DEBUG: Received token endpoint request: %s %s", r.Method, r.URL.Path) // Set CORS headers for all requests - em.setCORSHeaders(w) + em.setCORSHeaders(w, r) // Handle preflight OPTIONS request if r.Method == http.MethodOptions { @@ -866,8 +918,6 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { return } - log.Printf("OAuth DEBUG: Token request parameters validated - client_id: %s, redirect_uri: %s, has_code_verifier: %t", clientID, redirectURI, codeVerifier != "") - // Validate client ID (accept both configured and dynamically registered clients) if !em.isValidClientID(clientID) { log.Printf("OAuth ERROR: Invalid client_id: %s", clientID) @@ -875,6 +925,13 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { return } + // Validate redirect URI for security + if err := em.validateRedirectURI(redirectURI); err != nil { + log.Printf("OAuth ERROR: Redirect URI validation failed in token endpoint: %v", err) + em.writeErrorResponse(w, "invalid_request", "Invalid redirect_uri", http.StatusBadRequest) + return + } + // Extract scope from the token request (MCP client should send the same scope) requestedScope := r.FormValue("scope") if requestedScope == "" { @@ -892,8 +949,6 @@ func (em *EndpointManager) tokenHandler() http.HandlerFunc { return } - log.Printf("OAuth DEBUG: Token exchange successful, returning token response to client") - // Return token response w.Header().Set("Content-Type", "application/json") w.Header().Set("Cache-Control", "no-store") diff --git a/internal/auth/oauth/endpoints_test.go b/internal/auth/oauth/endpoints_test.go index f044220..f457d5c 100644 --- a/internal/auth/oauth/endpoints_test.go +++ b/internal/auth/oauth/endpoints_test.go @@ -21,6 +21,7 @@ func createTestConfig() *config.ConfigData { TenantID: "test-tenant", ClientID: "test-client", RequiredScopes: []string{"https://management.azure.com/.default"}, + RedirectURIs: []string{"http://127.0.0.1:8000/oauth/callback", "http://localhost:8000/oauth/callback"}, TokenValidation: auth.TokenValidationConfig{ ValidateJWT: false, ValidateAudience: false, @@ -390,3 +391,211 @@ func TestCallbackEndpointMethodNotAllowed(t *testing.T) { t.Errorf("Expected status 405 for POST method, got %d", w.Code) } } + +func TestValidateRedirectURI(t *testing.T) { + cfg := createTestConfig() + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + redirectURI string + wantErr bool + }{ + { + name: "valid redirect URI - 127.0.0.1", + redirectURI: "http://127.0.0.1:8000/oauth/callback", + wantErr: false, + }, + { + name: "valid redirect URI - localhost", + redirectURI: "http://localhost:8000/oauth/callback", + wantErr: false, + }, + { + name: "invalid redirect URI - wrong port", + redirectURI: "http://127.0.0.1:9000/oauth/callback", + wantErr: true, + }, + { + name: "invalid redirect URI - wrong path", + redirectURI: "http://127.0.0.1:8000/oauth/malicious", + wantErr: true, + }, + { + name: "invalid redirect URI - external domain", + redirectURI: "http://malicious.com:8000/oauth/callback", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := manager.validateRedirectURI(tt.redirectURI) + if (err != nil) != tt.wantErr { + t.Errorf("validateRedirectURI() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + + // Test with empty redirect URIs configuration + cfgEmpty := createTestConfig() + cfgEmpty.OAuthConfig.RedirectURIs = []string{} + managerEmpty := NewEndpointManager(provider, cfgEmpty) + + err := managerEmpty.validateRedirectURI("http://127.0.0.1:8000/oauth/callback") + if err == nil { + t.Error("Expected error when no redirect URIs are configured") + } +} + +// TestAuthorizationProxyRedirectURIValidation tests the authorization endpoint redirect URI validation +func TestCORSHeaders(t *testing.T) { + cfg := createTestConfig() + cfg.OAuthConfig.AllowedOrigins = []string{"http://localhost:6274"} + + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + origin string + expectCORSSet bool + expectOrigin string + }{ + { + name: "allowed origin", + origin: "http://localhost:6274", + expectCORSSet: true, + expectOrigin: "http://localhost:6274", + }, + { + name: "disallowed origin", + origin: "http://malicious.com", + expectCORSSet: false, + expectOrigin: "", + }, + { + name: "no origin header", + origin: "", + expectCORSSet: false, + expectOrigin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/health", nil) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + w := httptest.NewRecorder() + + handler := manager.healthHandler() + handler(w, req) + + corsOrigin := w.Header().Get("Access-Control-Allow-Origin") + if tt.expectCORSSet { + if corsOrigin != tt.expectOrigin { + t.Errorf("Expected CORS origin %s, got %s", tt.expectOrigin, corsOrigin) + } + } else { + if corsOrigin != "" { + t.Errorf("Expected no CORS headers, but got Access-Control-Allow-Origin: %s", corsOrigin) + } + } + }) + } +} + +func TestAuthorizationProxyRedirectURIValidation(t *testing.T) { + cfg := createTestConfig() + provider, _ := NewAzureOAuthProvider(cfg.OAuthConfig) + manager := NewEndpointManager(provider, cfg) + + tests := []struct { + name string + redirectURI string + expectError bool + expectCode int + }{ + { + name: "missing redirect_uri", + redirectURI: "", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "valid redirect_uri - 127.0.0.1", + redirectURI: "http://127.0.0.1:8000/oauth/callback", + expectError: false, + expectCode: http.StatusFound, // Should redirect to Azure AD + }, + { + name: "valid redirect_uri - localhost", + redirectURI: "http://localhost:8000/oauth/callback", + expectError: false, + expectCode: http.StatusFound, // Should redirect to Azure AD + }, + { + name: "invalid redirect_uri - wrong port", + redirectURI: "http://127.0.0.1:9000/oauth/callback", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "invalid redirect_uri - wrong path", + redirectURI: "http://127.0.0.1:8000/oauth/malicious", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "invalid redirect_uri - external domain", + redirectURI: "http://malicious.com:8000/oauth/callback", + expectError: true, + expectCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create request URL with redirect_uri parameter if provided + requestURL := "/oauth2/v2.0/authorize?response_type=code&client_id=test-client&code_challenge=test&code_challenge_method=S256&state=test" + if tt.redirectURI != "" { + requestURL += "&redirect_uri=" + tt.redirectURI + } + + req := httptest.NewRequest("GET", requestURL, nil) + w := httptest.NewRecorder() + + handler := manager.authorizationProxyHandler() + handler(w, req) + + if tt.expectError { + if w.Code != tt.expectCode { + t.Errorf("Expected status code %d, got %d", tt.expectCode, w.Code) + } + + // Check that error response contains helpful information + body := w.Body.String() + if !strings.Contains(body, "redirect_uri") { + t.Errorf("Error response should mention redirect_uri, got: %s", body) + } + } else { + if w.Code != tt.expectCode { + t.Errorf("Expected status code %d, got %d", tt.expectCode, w.Code) + } + + // For successful cases, check redirect location contains expected parameters + location := w.Header().Get("Location") + if location == "" { + t.Errorf("Expected redirect location header, got empty") + } + if !strings.Contains(location, "login.microsoftonline.com") { + t.Errorf("Expected redirect to Azure AD, got: %s", location) + } + } + }) + } +} diff --git a/internal/auth/oauth/middleware.go b/internal/auth/oauth/middleware.go index f951fef..5170385 100644 --- a/internal/auth/oauth/middleware.go +++ b/internal/auth/oauth/middleware.go @@ -22,15 +22,29 @@ type AuthMiddleware struct { serverURL string } -// setCORSHeaders sets CORS headers for OAuth endpoints to allow MCP Inspector access -func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter) { - origin := "*" // TODO: Restrict to specific origins - - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") - w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours - w.Header().Set("Access-Control-Allow-Credentials", "false") // Explicit false for wildcard origin +// setCORSHeaders sets CORS headers for OAuth endpoints with origin whitelisting +func (m *AuthMiddleware) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + requestOrigin := r.Header.Get("Origin") + + // Check if the request origin is in the allowed list + var allowedOrigin string + for _, allowed := range m.provider.config.AllowedOrigins { + if requestOrigin == allowed { + allowedOrigin = requestOrigin + break + } + } + + // Only set CORS headers if origin is allowed + if allowedOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, mcp-protocol-version") + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + w.Header().Set("Access-Control-Allow-Credentials", "false") + } else if requestOrigin != "" { + log.Printf("CORS ERROR: Origin %s is not in the allowed list - cross-origin requests will be blocked for security", requestOrigin) + } } // NewAuthMiddleware creates a new authentication middleware @@ -228,7 +242,7 @@ func (m *AuthMiddleware) hasScopePermission(requiredScope string, tokenScopes [] // handleAuthError handles authentication errors func (m *AuthMiddleware) handleAuthError(w http.ResponseWriter, r *http.Request, authResult *auth.AuthResult) { // Set CORS headers - m.setCORSHeaders(w) + m.setCORSHeaders(w, r) w.Header().Set("Content-Type", "application/json") // Add WWW-Authenticate header for 401 responses (RFC 9728 Section 5.1) diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go index 0ef4848..d6ec9bc 100644 --- a/internal/auth/oauth/provider.go +++ b/internal/auth/oauth/provider.go @@ -213,8 +213,11 @@ func (p *AzureOAuthProvider) ValidateToken(ctx context.Context, tokenString stri return nil, fmt.Errorf("invalid JWT token format: expected 3 parts separated by dots, got %d dots", dotCount) } - // If JWT validation is disabled, return a minimal token info without full validation + // SECURITY WARNING: JWT validation bypass - for development and testing ONLY + // ValidateJWT should ALWAYS be true in production environments + // This bypass creates a significant security vulnerability if enabled in production if !p.config.TokenValidation.ValidateJWT { + log.Printf("WARNING: JWT validation is DISABLED - this should ONLY be used in development/testing") return &auth.TokenInfo{ AccessToken: tokenString, TokenType: "Bearer", diff --git a/internal/auth/oauth/provider_test.go b/internal/auth/oauth/provider_test.go index c8722d5..657f2ba 100644 --- a/internal/auth/oauth/provider_test.go +++ b/internal/auth/oauth/provider_test.go @@ -267,6 +267,9 @@ func TestGetAuthorizationServerMetadata(t *testing.T) { } func TestValidateTokenWithoutJWT(t *testing.T) { + // SECURITY WARNING: This test verifies the JWT validation bypass functionality + // ValidateJWT=false should ONLY be used in development/testing environments + // This functionality should NEVER be enabled in production config := &auth.OAuthConfig{ Enabled: true, TenantID: "test-tenant", diff --git a/internal/auth/types.go b/internal/auth/types.go index f882a95..d89cca6 100644 --- a/internal/auth/types.go +++ b/internal/auth/types.go @@ -19,13 +19,21 @@ type OAuthConfig struct { // Required OAuth scopes for accessing AKS-MCP RequiredScopes []string `json:"required_scopes"` + // Allowed redirect URIs for OAuth callback + RedirectURIs []string `json:"redirect_uris"` + + // Allowed CORS origins for OAuth endpoints (for security, wildcard "*" should be avoided) + AllowedOrigins []string `json:"allowed_origins"` + // Token validation settings TokenValidation TokenValidationConfig `json:"token_validation"` } // TokenValidationConfig represents token validation configuration type TokenValidationConfig struct { - // Enable JWT token validation + // SECURITY CRITICAL: Enable JWT token validation + // Setting this to false creates a security vulnerability - for development/testing ONLY + // MUST be true in production environments ValidateJWT bool `json:"validate_jwt"` // Enable audience validation @@ -98,8 +106,10 @@ func NewDefaultOAuthConfig() *OAuthConfig { // Use Azure Management API scope to get v2.0 format tokens // This ensures we get v2.0 issuer format which works with v2.0 JWKS endpoints RequiredScopes: []string{AzureADScope}, // "https://management.azure.com/.default" + // RedirectURIs will be populated dynamically based on host/port configuration + RedirectURIs: []string{}, TokenValidation: TokenValidationConfig{ - ValidateJWT: true, // Re-enabled now that we'll get v2.0 tokens + ValidateJWT: true, // SECURITY CRITICAL: Always true in production ValidateAudience: true, // Re-enabled with correct audience ExpectedAudience: DefaultExpectedAudience, // "https://management.azure.com" CacheTTL: DefaultTokenCacheTTL, diff --git a/internal/config/config.go b/internal/config/config.go index 55c175a..b60f068 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "os" + "regexp" "strings" "time" @@ -16,9 +17,24 @@ import ( ) // EnableCache controls whether caching is enabled globally -// Set to false during debugging to avoid cache-related issues +// Cache is enabled by default for production performance // This affects both web cache headers and AzureOAuthProvider cache -const EnableCache = false +// Can be disabled via DISABLE_CACHE environment variable +var EnableCache = os.Getenv("DISABLE_CACHE") != "true" + +// validateGUID validates that a value is in valid GUID format +func validateGUID(value, name string) error { + if value == "" { + return nil // Empty values are allowed (will be handled by OAuth validation) + } + + // GUID pattern: 8-4-4-4-12 hexadecimal digits with hyphens + guidRegex := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + if !guidRegex.MatchString(value) { + return fmt.Errorf("%s must be a valid GUID format (e.g., 12345678-1234-1234-1234-123456789abc), got: %s", name, value) + } + return nil +} // ConfigData holds the global configuration type ConfigData struct { @@ -84,6 +100,14 @@ func (cfg *ConfigData) ParseFlags() { flag.StringVar(&cfg.OAuthConfig.TenantID, "oauth-tenant-id", "", "Azure AD tenant ID for OAuth (fallback to AZURE_TENANT_ID env var)") flag.StringVar(&cfg.OAuthConfig.ClientID, "oauth-client-id", "", "Azure AD client ID for OAuth (fallback to AZURE_CLIENT_ID env var)") + // OAuth redirect URIs configuration + additionalRedirectURIs := flag.String("oauth-redirects", "", + "Comma-separated list of additional OAuth redirect URIs (e.g. http://localhost:8000/oauth/callback,http://localhost:6274/oauth/callback)") + + // OAuth CORS origins configuration + allowedCORSOrigins := flag.String("oauth-cors-origins", "", + "Comma-separated list of allowed CORS origins for OAuth endpoints (e.g. http://localhost:6274). If empty, no cross-origin requests are allowed for security") + // Kubernetes-specific settings additionalTools := flag.String("additional-tools", "", "Comma-separated list of additional Kubernetes tools to support (kubectl is always enabled). Available: helm,cilium,hubble") @@ -129,7 +153,10 @@ func (cfg *ConfigData) ParseFlags() { cfg.SecurityConfig.AllowedNamespaces = cfg.AllowNamespaces // Parse OAuth configuration - cfg.parseOAuthConfig() + if err := cfg.parseOAuthConfig(*additionalRedirectURIs, *allowedCORSOrigins); err != nil { + fmt.Printf("OAuth configuration error: %v\n", err) + os.Exit(1) + } // Parse additional tools if *additionalTools != "" { @@ -141,17 +168,83 @@ func (cfg *ConfigData) ParseFlags() { } // parseOAuthConfig parses OAuth-related command line arguments -func (cfg *ConfigData) parseOAuthConfig() { +func (cfg *ConfigData) parseOAuthConfig(additionalRedirectURIs, allowedCORSOrigins string) error { // Note: OAuth scopes are automatically configured to use "https://management.azure.com/.default" // and are not configurable via command line per design + // Track configuration sources for logging + var tenantIDSource, clientIDSource string + // Load OAuth configuration from environment variables if not set via CLI if cfg.OAuthConfig.TenantID == "" { - cfg.OAuthConfig.TenantID = os.Getenv("AZURE_TENANT_ID") + if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" { + cfg.OAuthConfig.TenantID = tenantID + tenantIDSource = "environment variable AZURE_TENANT_ID" + log.Printf("OAuth Config: Using tenant ID from environment variable AZURE_TENANT_ID") + } + } else { + tenantIDSource = "command line flag --oauth-tenant-id" + log.Printf("OAuth Config: Using tenant ID from command line flag --oauth-tenant-id") } + if cfg.OAuthConfig.ClientID == "" { - cfg.OAuthConfig.ClientID = os.Getenv("AZURE_CLIENT_ID") + if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" { + cfg.OAuthConfig.ClientID = clientID + clientIDSource = "environment variable AZURE_CLIENT_ID" + log.Printf("OAuth Config: Using client ID from environment variable AZURE_CLIENT_ID") + } + } else { + clientIDSource = "command line flag --oauth-client-id" + log.Printf("OAuth Config: Using client ID from command line flag --oauth-client-id") } + + // Validate GUID formats for tenant ID and client ID + if err := validateGUID(cfg.OAuthConfig.TenantID, "OAuth tenant ID"); err != nil { + return fmt.Errorf("invalid OAuth tenant ID from %s: %w", tenantIDSource, err) + } + + if err := validateGUID(cfg.OAuthConfig.ClientID, "OAuth client ID"); err != nil { + return fmt.Errorf("invalid OAuth client ID from %s: %w", clientIDSource, err) + } + + // Set redirect URIs based on configured host and port + if cfg.OAuthConfig.Enabled { + redirectURI := fmt.Sprintf("http://%s:%d/oauth/callback", cfg.Host, cfg.Port) + cfg.OAuthConfig.RedirectURIs = []string{redirectURI} + + // Add localhost variant if using 127.0.0.1 + if cfg.Host == "127.0.0.1" { + localhostURI := fmt.Sprintf("http://localhost:%d/oauth/callback", cfg.Port) + cfg.OAuthConfig.RedirectURIs = append(cfg.OAuthConfig.RedirectURIs, localhostURI) + } + + // Add additional redirect URIs from command line flag + if additionalRedirectURIs != "" { + additionalURIs := strings.Split(additionalRedirectURIs, ",") + for _, uri := range additionalURIs { + trimmedURI := strings.TrimSpace(uri) + if trimmedURI != "" { + cfg.OAuthConfig.RedirectURIs = append(cfg.OAuthConfig.RedirectURIs, trimmedURI) + } + } + } + } + + // Parse allowed CORS origins for OAuth endpoints + if allowedCORSOrigins != "" { + log.Printf("OAuth Config: Setting allowed CORS origins from command line flag --oauth-cors-origins") + origins := strings.Split(allowedCORSOrigins, ",") + for _, origin := range origins { + trimmedOrigin := strings.TrimSpace(origin) + if trimmedOrigin != "" { + cfg.OAuthConfig.AllowedOrigins = append(cfg.OAuthConfig.AllowedOrigins, trimmedOrigin) + } + } + } else { + log.Printf("OAuth Config: No CORS origins configured - cross-origin requests will be blocked for security") + } + + return nil } // ValidateConfig validates the configuration for incompatible settings diff --git a/internal/config/config_test.go b/internal/config/config_test.go index a70c989..56ad2e4 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -5,23 +5,235 @@ import ( ) func TestBasicOAuthConfig(t *testing.T) { - // Test basic OAuth configuration parsing + // Test basic OAuth configuration parsing with valid GUIDs cfg := NewConfig() cfg.OAuthConfig.Enabled = true - cfg.OAuthConfig.TenantID = "test-tenant" - cfg.OAuthConfig.ClientID = "test-client" + cfg.OAuthConfig.TenantID = "12345678-1234-1234-1234-123456789abc" + cfg.OAuthConfig.ClientID = "87654321-4321-4321-4321-cba987654321" // Parse OAuth configuration - cfg.parseOAuthConfig() + if err := cfg.parseOAuthConfig("", ""); err != nil { + t.Fatalf("Unexpected error in parseOAuthConfig: %v", err) + } // Verify basic configuration is preserved if !cfg.OAuthConfig.Enabled { t.Error("Expected OAuth to be enabled") } - if cfg.OAuthConfig.TenantID != "test-tenant" { - t.Errorf("Expected tenant ID 'test-tenant', got %s", cfg.OAuthConfig.TenantID) + if cfg.OAuthConfig.TenantID != "12345678-1234-1234-1234-123456789abc" { + t.Errorf("Expected tenant ID '12345678-1234-1234-1234-123456789abc', got %s", cfg.OAuthConfig.TenantID) + } + if cfg.OAuthConfig.ClientID != "87654321-4321-4321-4321-cba987654321" { + t.Errorf("Expected client ID '87654321-4321-4321-4321-cba987654321', got %s", cfg.OAuthConfig.ClientID) } - if cfg.OAuthConfig.ClientID != "test-client" { - t.Errorf("Expected client ID 'test-client', got %s", cfg.OAuthConfig.ClientID) +} + +func TestOAuthRedirectURIsConfig(t *testing.T) { + // Test OAuth redirect URIs configuration with additional URIs + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.Host = "127.0.0.1" + cfg.Port = 8081 + + // Test with additional redirect URIs + additionalRedirectURIs := "http://localhost:6274/oauth/callback,http://localhost:8080/oauth/callback" + if err := cfg.parseOAuthConfig(additionalRedirectURIs, ""); err != nil { + t.Fatalf("Unexpected error in parseOAuthConfig: %v", err) + } + + // Should have default URIs plus additional ones + expectedURIs := []string{ + "http://127.0.0.1:8081/oauth/callback", + "http://localhost:8081/oauth/callback", + "http://localhost:6274/oauth/callback", + "http://localhost:8080/oauth/callback", + } + + if len(cfg.OAuthConfig.RedirectURIs) != len(expectedURIs) { + t.Errorf("Expected %d redirect URIs, got %d", len(expectedURIs), len(cfg.OAuthConfig.RedirectURIs)) + } + + for i, expected := range expectedURIs { + if i >= len(cfg.OAuthConfig.RedirectURIs) || cfg.OAuthConfig.RedirectURIs[i] != expected { + t.Errorf("Expected redirect URI '%s' at index %d, got '%s'", expected, i, + func() string { + if i < len(cfg.OAuthConfig.RedirectURIs) { + return cfg.OAuthConfig.RedirectURIs[i] + } + return "missing" + }()) + } + } +} + +func TestOAuthRedirectURIsEmptyAdditional(t *testing.T) { + // Test OAuth redirect URIs configuration without additional URIs + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.Host = "127.0.0.1" + cfg.Port = 8081 + + // Test with empty additional redirect URIs + if err := cfg.parseOAuthConfig("", ""); err != nil { + t.Fatalf("Unexpected error in parseOAuthConfig: %v", err) + } + + // Should have only default URIs + expectedURIs := []string{ + "http://127.0.0.1:8081/oauth/callback", + "http://localhost:8081/oauth/callback", + } + + if len(cfg.OAuthConfig.RedirectURIs) != len(expectedURIs) { + t.Errorf("Expected %d redirect URIs, got %d", len(expectedURIs), len(cfg.OAuthConfig.RedirectURIs)) + } + + for i, expected := range expectedURIs { + if cfg.OAuthConfig.RedirectURIs[i] != expected { + t.Errorf("Expected redirect URI '%s' at index %d, got '%s'", expected, i, cfg.OAuthConfig.RedirectURIs[i]) + } + } +} + +func TestValidateGUID(t *testing.T) { + tests := []struct { + name string + value string + fieldName string + wantErr bool + }{ + { + name: "valid GUID", + value: "12345678-1234-1234-1234-123456789abc", + fieldName: "test field", + wantErr: false, + }, + { + name: "valid GUID uppercase", + value: "12345678-1234-1234-1234-123456789ABC", + fieldName: "test field", + wantErr: false, + }, + { + name: "empty value allowed", + value: "", + fieldName: "test field", + wantErr: false, + }, + { + name: "invalid format - missing hyphens", + value: "123456781234123412341234567890ab", + fieldName: "test field", + wantErr: true, + }, + { + name: "invalid format - wrong length", + value: "12345678-1234-1234-1234-123456789", + fieldName: "test field", + wantErr: true, + }, + { + name: "invalid format - non-hex characters", + value: "12345678-1234-1234-1234-123456789abg", + fieldName: "test field", + wantErr: true, + }, + { + name: "invalid format - extra hyphens", + value: "12345678-1234-1234-1234-1234-56789abc", + fieldName: "test field", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateGUID(tt.value, tt.fieldName) + if (err != nil) != tt.wantErr { + t.Errorf("validateGUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err != nil { + // Verify error message contains the field name and value + errorMsg := err.Error() + if !contains(errorMsg, tt.fieldName) { + t.Errorf("Error message should contain field name '%s', got: %s", tt.fieldName, errorMsg) + } + if tt.value != "" && !contains(errorMsg, tt.value) { + t.Errorf("Error message should contain value '%s', got: %s", tt.value, errorMsg) + } + } + }) + } +} + +func TestOAuthGUIDValidation(t *testing.T) { + tests := []struct { + name string + tenantID string + clientID string + wantErr bool + }{ + { + name: "valid GUIDs", + tenantID: "12345678-1234-1234-1234-123456789abc", + clientID: "87654321-4321-4321-4321-cba987654321", + wantErr: false, + }, + { + name: "empty values allowed", + tenantID: "", + clientID: "", + wantErr: false, + }, + { + name: "invalid tenant ID", + tenantID: "invalid-tenant-id", + clientID: "87654321-4321-4321-4321-cba987654321", + wantErr: true, + }, + { + name: "invalid client ID", + tenantID: "12345678-1234-1234-1234-123456789abc", + clientID: "invalid-client-id", + wantErr: true, + }, + { + name: "both invalid", + tenantID: "invalid-tenant", + clientID: "invalid-client", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := NewConfig() + cfg.OAuthConfig.Enabled = true + cfg.OAuthConfig.TenantID = tt.tenantID + cfg.OAuthConfig.ClientID = tt.clientID + cfg.Host = "127.0.0.1" + cfg.Port = 8081 + + err := cfg.parseOAuthConfig("", "") + if (err != nil) != tt.wantErr { + t.Errorf("parseOAuthConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +// contains is a helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(substr) == 0 || (len(s) >= len(substr) && findSubstring(s, substr)) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } } + return false }