diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma
index e227c41f93ad..f876d63520bc 100644
--- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma
+++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
   session_id               String?
   status                   String?
   mcp_namespaced_tool_name String?
+  agent_id                 String?
   proxy_server_request     Json?    @default("{}")
   @@index([startTime])
   @@index([end_user])
diff --git a/litellm/a2a_protocol/main.py b/litellm/a2a_protocol/main.py
index 4d89dbb94542..30e13acc1ff8 100644
--- a/litellm/a2a_protocol/main.py
+++ b/litellm/a2a_protocol/main.py
@@ -5,11 +5,14 @@
 """

 import asyncio
+import datetime
 from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union

 import litellm
 from litellm._logging import verbose_logger
+from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
 from litellm.a2a_protocol.utils import A2ARequestUtils
+from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
 )
@@ -63,6 +66,26 @@ def _set_usage_on_logging_obj(
         litellm_logging_obj.model_call_details["usage"] = usage


+def _set_agent_id_on_logging_obj(
+    kwargs: Dict[str, Any],
+    agent_id: Optional[str],
+) -> None:
+    """
+    Set agent_id on litellm_logging_obj for SpendLogs tracking.
+
+    Args:
+        kwargs: The kwargs dict containing litellm_logging_obj
+        agent_id: The A2A agent ID
+    """
+    if agent_id is None:
+        return
+
+    litellm_logging_obj = kwargs.get("litellm_logging_obj")
+    if litellm_logging_obj is not None:
+        # Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
+        litellm_logging_obj.model_call_details["agent_id"] = agent_id
+
+
 def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
     """
     Extract agent info and set model/custom_llm_provider for cost tracking.
@@ -101,6 +124,7 @@ async def asend_message(
     request: Optional["SendMessageRequest"] = None,
     api_base: Optional[str] = None,
     litellm_params: Optional[Dict[str, Any]] = None,
+    agent_id: Optional[str] = None,
     **kwargs: Any,
 ) -> LiteLLMSendMessageResponse:
     """
@@ -114,6 +138,7 @@ async def asend_message(
         request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
         api_base: API base URL (required for completion bridge, optional for standard A2A)
         litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
+        agent_id: Optional agent ID for tracking in SpendLogs
         **kwargs: Additional arguments passed to the client decorator

     Returns:
         LiteLLMSendMessageResponse
@@ -186,11 +211,15 @@ async def asend_message(
         return LiteLLMSendMessageResponse.from_dict(response_dict)

     # Standard A2A client flow
-    if a2a_client is None:
-        raise ValueError("a2a_client is required for standard A2A flow")
     if request is None:
         raise ValueError("request is required")

+    # Create A2A client if not provided but api_base is available
+    if a2a_client is None:
+        if api_base is None:
+            raise ValueError("Either a2a_client or api_base is required for standard A2A flow")
+        a2a_client = await create_a2a_client(base_url=api_base)
+
     agent_name = _get_a2a_model_info(a2a_client, kwargs)
     verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")
@@ -216,6 +245,9 @@ async def asend_message(
         completion_tokens=completion_tokens,
     )

+    # Set agent_id on logging obj for SpendLogs tracking
+    _set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)
+
     return response
@@ -254,6 +286,9 @@ async def asend_message_streaming(
     request: Optional["SendStreamingMessageRequest"] = None,
     api_base: Optional[str] = None,
     litellm_params: Optional[Dict[str, Any]] = None,
+    agent_id: Optional[str] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+    proxy_server_request: Optional[Dict[str, Any]] = None,
 ) -> AsyncIterator[Any]:
     """
     Async: Send a streaming message to an A2A agent.
@@ -265,6 +300,9 @@ async def asend_message_streaming(
         request: SendStreamingMessageRequest from a2a.types
         api_base: API base URL (required for completion bridge)
         litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
+        agent_id: Optional agent ID for tracking in SpendLogs
+        metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
+        proxy_server_request: Optional proxy server request data

     Yields:
         SendStreamingMessageResponse chunks from the agent
@@ -320,23 +358,66 @@ async def asend_message_streaming(
         return

     # Standard A2A client flow
-    if a2a_client is None:
-        raise ValueError("a2a_client is required for standard A2A flow")
     if request is None:
         raise ValueError("request is required")

+    # Create A2A client if not provided but api_base is available
+    if a2a_client is None:
+        if api_base is None:
+            raise ValueError("Either a2a_client or api_base is required for standard A2A flow")
+        a2a_client = await create_a2a_client(base_url=api_base)
+
     verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")

+    # Track for logging
+    import datetime
+
+    start_time = datetime.datetime.now()
     stream = a2a_client.send_message_streaming(request)

-    chunk_count = 0
-    async for chunk in stream:
-        chunk_count += 1
-        yield chunk
+    # Build logging object for streaming completion callbacks
+    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(a2a_client, "agent_card", None)
+    agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"
+    model = f"a2a_agent/{agent_name}"

-    verbose_logger.info(
-        f"A2A send_message_streaming completed, request_id={request.id}, chunks={chunk_count}"
+    logging_obj = Logging(
+        model=model,
+        messages=[{"role": "user", "content": "streaming-request"}],
+        stream=False,  # complete response logging after stream ends
+        call_type="asend_message_streaming",
+        start_time=start_time,
+        litellm_call_id=str(request.id),
+        function_id=str(request.id),
     )
+    logging_obj.model = model
+    logging_obj.custom_llm_provider = "a2a_agent"
+    logging_obj.model_call_details["model"] = model
+    logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
+    if agent_id:
+        logging_obj.model_call_details["agent_id"] = agent_id
+
+    # Propagate litellm_params for spend logging (includes cost_per_query, etc.)
+    _litellm_params = litellm_params.copy() if litellm_params else {}
+    # Merge metadata into litellm_params.metadata (required for proxy cost tracking)
+    if metadata:
+        _litellm_params["metadata"] = metadata
+    if proxy_server_request:
+        _litellm_params["proxy_server_request"] = proxy_server_request
+
+    logging_obj.litellm_params = _litellm_params
+    logging_obj.optional_params = _litellm_params  # used by cost calc
+    logging_obj.model_call_details["litellm_params"] = _litellm_params
+    logging_obj.model_call_details["metadata"] = metadata or {}
+
+    iterator = A2AStreamingIterator(
+        stream=stream,
+        request=request,
+        logging_obj=logging_obj,
+        agent_name=agent_name,
+    )
+
+    async for chunk in iterator:
+        yield chunk


 async def create_a2a_client(
diff --git a/litellm/a2a_protocol/streaming_iterator.py b/litellm/a2a_protocol/streaming_iterator.py
index d7bd27246a47..921dc0e52e09 100644
--- a/litellm/a2a_protocol/streaming_iterator.py
+++ b/litellm/a2a_protocol/streaming_iterator.py
@@ -8,6 +8,7 @@

 import litellm
 from litellm._logging import verbose_logger
+from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
 from litellm.a2a_protocol.utils import A2ARequestUtils
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.thread_pool_executor import executor
@@ -115,11 +116,17 @@ async def _handle_stream_complete(self) -> None:
             # Set usage on logging obj
             self.logging_obj.model_call_details["usage"] = usage
+            # Mark stream flag for downstream callbacks
+            self.logging_obj.model_call_details["stream"] = False
+
+            # Calculate cost using A2ACostCalculator
+            response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
+            self.logging_obj.model_call_details["response_cost"] = response_cost

             # Build result for logging
             result = self._build_logging_result(usage)

-            # Call success handlers
+            # Call success handlers - they will build standard_logging_object
             asyncio.create_task(
                 self.logging_obj.async_success_handler(
                     result=result,
@@ -139,7 +146,8 @@ async def _handle_stream_complete(self) -> None:

             verbose_logger.info(
                 f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
-                f"completion_tokens={completion_tokens}, total_tokens={total_tokens}"
+                f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
+                f"response_cost={response_cost}"
             )

         except Exception as e:
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 083ac07340ad..6e5d4f671117 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2677,6 +2677,7 @@ class SpendLogsPayload(TypedDict):
     model_id: Optional[str]
     model_group: Optional[str]
     mcp_namespaced_tool_name: Optional[str]
+    agent_id: Optional[str]
     api_base: str
     user: str
     metadata: str  # json str
diff --git a/litellm/proxy/agent_endpoints/a2a_endpoints.py b/litellm/proxy/agent_endpoints/a2a_endpoints.py
index 84bbbb5d8a5a..90b1507b3863 100644
--- a/litellm/proxy/agent_endpoints/a2a_endpoints.py
+++ b/litellm/proxy/agent_endpoints/a2a_endpoints.py
@@ -50,6 +50,9 @@ async def _handle_stream_message(
     request_id: str,
     params: dict,
     litellm_params: Optional[dict] = None,
+    agent_id: Optional[str] = None,
+    metadata: Optional[dict] = None,
+    proxy_server_request: Optional[dict] = None,
 ) -> StreamingResponse:
     """Handle message/stream method via SDK functions."""
     from a2a.types import MessageSendParams, SendStreamingMessageRequest
@@ -66,6 +69,9 @@ async def stream_response():
             request=a2a_request,
             api_base=api_base,
             litellm_params=litellm_params,
+            agent_id=agent_id,
+            metadata=metadata,
+            proxy_server_request=proxy_server_request,
         ):
             # Chunk may be dict or object depending on bridge vs standard path
             if hasattr(chunk, "model_dump"):
@@ -241,6 +247,7 @@ async def invoke_agent_a2a(
             request=a2a_request,
             api_base=agent_url,
             litellm_params=litellm_params,
+            agent_id=agent.agent_id,
             metadata=data.get("metadata", {}),
             proxy_server_request=data.get("proxy_server_request"),
         )
@@ -252,6 +259,9 @@ async def invoke_agent_a2a(
             request_id=request_id,
             params=params,
             litellm_params=litellm_params,
+            agent_id=agent.agent_id,
+            metadata=data.get("metadata", {}),
+            proxy_server_request=data.get("proxy_server_request"),
         )
     else:
         return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index e227c41f93ad..f876d63520bc 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
   session_id               String?
   status                   String?
   mcp_namespaced_tool_name String?
+  agent_id                 String?
   proxy_server_request     Json?    @default("{}")
   @@index([startTime])
   @@index([end_user])
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index ddb9cb90e3c1..090d870ba725 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -225,13 +225,16 @@ def get_logging_payload(  # noqa: PLR0915
         response_obj_dict = {}

     # Handle OCR responses which use usage_info instead of usage
+    usage: dict = {}
     if call_type in ["ocr", "aocr"]:
         usage = _extract_usage_for_ocr_call(response_obj, response_obj_dict)
     else:
         # Use response_obj_dict instead of response_obj to avoid calling .get() on Pydantic models
-        usage = response_obj_dict.get("usage", None) or {}
-        if isinstance(usage, litellm.Usage):
-            usage = dict(usage)
+        _usage = response_obj_dict.get("usage", None) or {}
+        if isinstance(_usage, litellm.Usage):
+            usage = dict(_usage)
+        elif isinstance(_usage, dict):
+            usage = _usage

     id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
     standard_logging_payload = cast(
@@ -369,6 +372,9 @@ def get_logging_payload(  # noqa: PLR0915
         "namespaced_tool_name", None
     )

+    # Extract agent_id for A2A requests (set directly on model_call_details)
+    agent_id: Optional[str] = kwargs.get("agent_id")
+
     try:
         payload: SpendLogsPayload = SpendLogsPayload(
             request_id=str(id),
@@ -396,6 +402,7 @@ def get_logging_payload(  # noqa: PLR0915
             model_group=_model_group,
             model_id=_model_id,
             mcp_namespaced_tool_name=mcp_namespaced_tool_name,
+            agent_id=agent_id,
             requester_ip_address=clean_metadata.get("requester_ip_address", None),
             custom_llm_provider=kwargs.get("custom_llm_provider", ""),
             messages=_get_messages_for_spend_logs_payload(
diff --git a/schema.prisma b/schema.prisma
index e227c41f93ad..f876d63520bc 100644
--- a/schema.prisma
+++ b/schema.prisma
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
   session_id               String?
   status                   String?
   mcp_namespaced_tool_name String?
+  agent_id                 String?
   proxy_server_request     Json?    @default("{}")
@default("{}") @@index([startTime]) @@index([end_user]) diff --git a/tests/test_litellm/a2a_protocol/test_cost_calculator.py b/tests/test_litellm/a2a_protocol/test_cost_calculator.py index 9f0948dc5099..0a472c089b1b 100644 --- a/tests/test_litellm/a2a_protocol/test_cost_calculator.py +++ b/tests/test_litellm/a2a_protocol/test_cost_calculator.py @@ -164,3 +164,187 @@ async def test_asend_message_uses_input_output_cost_per_token(): # Verify exact cost calculation assert response_cost == expected_cost, f"response_cost {response_cost} should equal expected {expected_cost}" + + +class AgentIdLogger(CustomLogger): + """Custom logger to capture agent_id from kwargs.""" + + def __init__(self): + self.agent_id: Optional[str] = None + self.kwargs: Optional[dict] = None + super().__init__() + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.kwargs = kwargs + self.agent_id = kwargs.get("agent_id") + + +@pytest.mark.asyncio +async def test_asend_message_passes_agent_id_to_callback(): + """ + Test that asend_message passes agent_id to callbacks via kwargs. + """ + from litellm.a2a_protocol import asend_message + + # Setup logger + litellm.logging_callback_manager._reset_all_callbacks() + agent_id_logger = AgentIdLogger() + litellm.callbacks = [agent_id_logger] + + # Mock A2A client + mock_client = MagicMock() + mock_client._litellm_agent_card = MagicMock() + mock_client._litellm_agent_card.name = "test-agent" + + # Mock response + mock_response = MagicMock() + mock_response.model_dump = MagicMock(return_value={ + "id": "test-123", + "jsonrpc": "2.0", + "result": {"status": "completed"}, + }) + mock_client.send_message = AsyncMock(return_value=mock_response) + + # Mock request + mock_request = MagicMock() + mock_request.id = "test-123" + + test_agent_id = "agent-uuid-12345" + + # Call asend_message with agent_id + await asend_message( + a2a_client=mock_client, + request=mock_request, + agent_id=test_agent_id, + ) + + await asyncio.sleep(0.1) + + # Verify agent_id was passed to callback + assert agent_id_logger.agent_id == test_agent_id, f"Expected agent_id '{test_agent_id}', got '{agent_id_logger.agent_id}'" + + +class MetadataLogger(CustomLogger): + """Custom logger to capture metadata from kwargs for proxy spend tracking.""" + + def __init__(self): + self.metadata: Optional[dict] = None + self.litellm_params: Optional[dict] = None + self.user_api_key: Optional[str] = None + self.user_id: Optional[str] = None + self.team_id: Optional[str] = None + super().__init__() + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.litellm_params = kwargs.get("litellm_params", {}) + self.metadata = self.litellm_params.get("metadata", {}) + self.user_api_key = self.metadata.get("user_api_key") + self.user_id = self.metadata.get("user_api_key_user_id") + self.team_id = self.metadata.get("user_api_key_team_id") + + +@pytest.mark.asyncio +async def test_asend_message_streaming_propagates_metadata(): + """ + Test that asend_message_streaming propagates metadata to logging object. + This ensures user_api_key, user_id, team_id are available for SpendLogs. 
+ """ + from litellm.a2a_protocol import asend_message_streaming + + # Setup logger + litellm.logging_callback_manager._reset_all_callbacks() + metadata_logger = MetadataLogger() + litellm.logging_callback_manager.add_litellm_async_success_callback(metadata_logger) + + # Mock A2A client + mock_client = MagicMock() + mock_client._litellm_agent_card = MagicMock() + mock_client._litellm_agent_card.name = "test-agent" + + # Mock streaming response + async def mock_stream(): + yield MagicMock(model_dump=lambda mode, exclude_none: {"chunk": 1}) + yield MagicMock(model_dump=lambda mode, exclude_none: {"chunk": 2}) + + mock_client.send_message_streaming = MagicMock(return_value=mock_stream()) + + # Mock request + mock_request = MagicMock() + mock_request.id = "test-stream-metadata" + mock_request.params = MagicMock() + mock_request.params.message = {"role": "user", "parts": [{"kind": "text", "text": "Hello"}]} + + # Metadata from proxy (contains user_api_key, user_id, team_id for SpendLogs) + test_metadata = { + "user_api_key": "sk-test-key-hash-12345", + "user_api_key_user_id": "user-uuid-123", + "user_api_key_team_id": "team-uuid-456", + } + + # Consume streaming response with metadata + chunks = [] + async for chunk in asend_message_streaming( + a2a_client=mock_client, + request=mock_request, + metadata=test_metadata, + ): + chunks.append(chunk) + + await asyncio.sleep(0.2) + + # Verify metadata was propagated to callback + assert metadata_logger.user_api_key == "sk-test-key-hash-12345" + assert metadata_logger.user_id == "user-uuid-123" + assert metadata_logger.team_id == "team-uuid-456" + + +@pytest.mark.asyncio +async def test_asend_message_streaming_triggers_callbacks(): + """ + Test that asend_message_streaming triggers callbacks after stream completes. 
+ """ + from litellm.a2a_protocol import asend_message_streaming + + # Setup logger - must use logging_callback_manager to properly register + litellm.logging_callback_manager._reset_all_callbacks() + callback_logger = AgentIdLogger() + litellm.logging_callback_manager.add_litellm_async_success_callback(callback_logger) + litellm.logging_callback_manager.add_litellm_success_callback(callback_logger) + + # Mock A2A client + mock_client = MagicMock() + mock_client._litellm_agent_card = MagicMock() + mock_client._litellm_agent_card.name = "test-agent" + + # Mock streaming response + async def mock_stream(): + yield MagicMock(model_dump=lambda mode, exclude_none: {"chunk": 1}) + yield MagicMock(model_dump=lambda mode, exclude_none: {"chunk": 2}) + + mock_client.send_message_streaming = MagicMock(return_value=mock_stream()) + + # Mock request + mock_request = MagicMock() + mock_request.id = "test-stream-123" + mock_request.params = MagicMock() + mock_request.params.message = {"role": "user", "parts": [{"kind": "text", "text": "Hello"}]} + + test_agent_id = "test-agent-id-streaming" + + # Consume streaming response + chunks = [] + async for chunk in asend_message_streaming( + a2a_client=mock_client, + request=mock_request, + agent_id=test_agent_id, + ): + chunks.append(chunk) + + await asyncio.sleep(0.2) + + # Verify chunks were received + assert len(chunks) == 2 + + # Verify callbacks WERE triggered after stream completed + assert callback_logger.kwargs is not None, "Streaming should trigger callbacks after completion" + assert callback_logger.agent_id == test_agent_id, f"Expected agent_id '{test_agent_id}', got '{callback_logger.agent_id}'" diff --git a/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py b/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py index 1c137c43ba93..5adf0bb1a3de 100644 --- a/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py +++ b/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py @@ -594,3 +594,41 @@ async def mock_update_database( print("- Both SpendLogs AND DailyUserSpend will have correct api_key") print("="*80 + "\n") + +@patch("litellm.proxy.proxy_server.master_key", None) +@patch("litellm.proxy.proxy_server.general_settings", {}) +def test_get_logging_payload_includes_agent_id_from_kwargs(): + """ + Test that get_logging_payload extracts agent_id from kwargs and includes it in the payload. + """ + test_agent_id = "agent-uuid-12345" + + kwargs = { + "model": "a2a_agent/test-agent", + "custom_llm_provider": "a2a_agent", + "agent_id": test_agent_id, + "litellm_params": { + "metadata": { + "user_api_key": "sk-test-key", + } + }, + } + + response_obj = { + "id": "test-response-123", + "jsonrpc": "2.0", + "result": {"status": "completed"}, + } + + start_time = datetime.datetime.now(timezone.utc) + end_time = datetime.datetime.now(timezone.utc) + + payload = get_logging_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + + assert payload["agent_id"] == test_agent_id, f"Expected agent_id '{test_agent_id}', got '{payload.get('agent_id')}'" +