55"""
66
77import asyncio
8+ import datetime
89from typing import TYPE_CHECKING , Any , AsyncIterator , Coroutine , Dict , Optional , Union
910
1011import litellm
1112from litellm ._logging import verbose_logger
13+ from litellm .a2a_protocol .streaming_iterator import A2AStreamingIterator
1214from litellm .a2a_protocol .utils import A2ARequestUtils
15+ from litellm .litellm_core_utils .litellm_logging import Logging
1316from litellm .llms .custom_httpx .http_handler import (
1417 get_async_httpx_client ,
1518 httpxSpecialProvider ,
@@ -63,6 +66,26 @@ def _set_usage_on_logging_obj(
6366 litellm_logging_obj .model_call_details ["usage" ] = usage
6467
6568
69+ def _set_agent_id_on_logging_obj (
70+ kwargs : Dict [str , Any ],
71+ agent_id : Optional [str ],
72+ ) -> None :
73+ """
74+ Set agent_id on litellm_logging_obj for SpendLogs tracking.
75+
76+ Args:
77+ kwargs: The kwargs dict containing litellm_logging_obj
78+ agent_id: The A2A agent ID
79+ """
80+ if agent_id is None :
81+ return
82+
83+ litellm_logging_obj = kwargs .get ("litellm_logging_obj" )
84+ if litellm_logging_obj is not None :
85+ # Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
86+ litellm_logging_obj .model_call_details ["agent_id" ] = agent_id
87+
88+
6689def _get_a2a_model_info (a2a_client : Any , kwargs : Dict [str , Any ]) -> str :
6790 """
6891 Extract agent info and set model/custom_llm_provider for cost tracking.
@@ -101,6 +124,7 @@ async def asend_message(
101124 request : Optional ["SendMessageRequest" ] = None ,
102125 api_base : Optional [str ] = None ,
103126 litellm_params : Optional [Dict [str , Any ]] = None ,
127+ agent_id : Optional [str ] = None ,
104128 ** kwargs : Any ,
105129) -> LiteLLMSendMessageResponse :
106130 """
@@ -114,6 +138,7 @@ async def asend_message(
114138 request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
115139 api_base: API base URL (required for completion bridge, optional for standard A2A)
116140 litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
141+ agent_id: Optional agent ID for tracking in SpendLogs
117142 **kwargs: Additional arguments passed to the client decorator
118143
119144 Returns:
@@ -186,11 +211,15 @@ async def asend_message(
186211 return LiteLLMSendMessageResponse .from_dict (response_dict )
187212
188213 # Standard A2A client flow
189- if a2a_client is None :
190- raise ValueError ("a2a_client is required for standard A2A flow" )
191214 if request is None :
192215 raise ValueError ("request is required" )
193216
217+ # Create A2A client if not provided but api_base is available
218+ if a2a_client is None :
219+ if api_base is None :
220+ raise ValueError ("Either a2a_client or api_base is required for standard A2A flow" )
221+ a2a_client = await create_a2a_client (base_url = api_base )
222+
194223 agent_name = _get_a2a_model_info (a2a_client , kwargs )
195224
196225 verbose_logger .info (f"A2A send_message request_id={ request .id } , agent={ agent_name } " )
@@ -216,6 +245,9 @@ async def asend_message(
216245 completion_tokens = completion_tokens ,
217246 )
218247
248+ # Set agent_id on logging obj for SpendLogs tracking
249+ _set_agent_id_on_logging_obj (kwargs = kwargs , agent_id = agent_id )
250+
219251 return response
220252
221253
@@ -254,6 +286,9 @@ async def asend_message_streaming(
254286 request : Optional ["SendStreamingMessageRequest" ] = None ,
255287 api_base : Optional [str ] = None ,
256288 litellm_params : Optional [Dict [str , Any ]] = None ,
289+ agent_id : Optional [str ] = None ,
290+ metadata : Optional [Dict [str , Any ]] = None ,
291+ proxy_server_request : Optional [Dict [str , Any ]] = None ,
257292) -> AsyncIterator [Any ]:
258293 """
259294 Async: Send a streaming message to an A2A agent.
@@ -265,6 +300,9 @@ async def asend_message_streaming(
265300 request: SendStreamingMessageRequest from a2a.types
266301 api_base: API base URL (required for completion bridge)
267302 litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
303+ agent_id: Optional agent ID for tracking in SpendLogs
304+ metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
305+ proxy_server_request: Optional proxy server request data
268306
269307 Yields:
270308 SendStreamingMessageResponse chunks from the agent
@@ -320,23 +358,66 @@ async def asend_message_streaming(
320358 return
321359
322360 # Standard A2A client flow
323- if a2a_client is None :
324- raise ValueError ("a2a_client is required for standard A2A flow" )
325361 if request is None :
326362 raise ValueError ("request is required" )
327363
364+ # Create A2A client if not provided but api_base is available
365+ if a2a_client is None :
366+ if api_base is None :
367+ raise ValueError ("Either a2a_client or api_base is required for standard A2A flow" )
368+ a2a_client = await create_a2a_client (base_url = api_base )
369+
328370 verbose_logger .info (f"A2A send_message_streaming request_id={ request .id } " )
329371
372+ # Track for logging
373+ import datetime
374+
375+ start_time = datetime .datetime .now ()
330376 stream = a2a_client .send_message_streaming (request )
331377
332- chunk_count = 0
333- async for chunk in stream :
334- chunk_count += 1
335- yield chunk
378+ # Build logging object for streaming completion callbacks
379+ agent_card = getattr ( a2a_client , "_litellm_agent_card" , None ) or getattr ( a2a_client , "agent_card" , None )
380+ agent_name = getattr ( agent_card , "name" , "unknown" ) if agent_card else "unknown"
381+ model = f"a2a_agent/ { agent_name } "
336382
337- verbose_logger .info (
338- f"A2A send_message_streaming completed, request_id={ request .id } , chunks={ chunk_count } "
383+ logging_obj = Logging (
384+ model = model ,
385+ messages = [{"role" : "user" , "content" : "streaming-request" }],
386+ stream = False , # complete response logging after stream ends
387+ call_type = "asend_message_streaming" ,
388+ start_time = start_time ,
389+ litellm_call_id = str (request .id ),
390+ function_id = str (request .id ),
339391 )
392+ logging_obj .model = model
393+ logging_obj .custom_llm_provider = "a2a_agent"
394+ logging_obj .model_call_details ["model" ] = model
395+ logging_obj .model_call_details ["custom_llm_provider" ] = "a2a_agent"
396+ if agent_id :
397+ logging_obj .model_call_details ["agent_id" ] = agent_id
398+
399+ # Propagate litellm_params for spend logging (includes cost_per_query, etc.)
400+ _litellm_params = litellm_params .copy () if litellm_params else {}
401+ # Merge metadata into litellm_params.metadata (required for proxy cost tracking)
402+ if metadata :
403+ _litellm_params ["metadata" ] = metadata
404+ if proxy_server_request :
405+ _litellm_params ["proxy_server_request" ] = proxy_server_request
406+
407+ logging_obj .litellm_params = _litellm_params
408+ logging_obj .optional_params = _litellm_params # used by cost calc
409+ logging_obj .model_call_details ["litellm_params" ] = _litellm_params
410+ logging_obj .model_call_details ["metadata" ] = metadata or {}
411+
412+ iterator = A2AStreamingIterator (
413+ stream = stream ,
414+ request = request ,
415+ logging_obj = logging_obj ,
416+ agent_name = agent_name ,
417+ )
418+
419+ async for chunk in iterator :
420+ yield chunk
340421
341422
342423async def create_a2a_client (
0 commit comments