Skip to content

Commit 059fedb

Browse files
authored
[Feat] Agent Gateway - Track agent_id in SpendLogs (#17795)
* add agent_id in metadata in spend logs
* add agent_id in SpendLogsPayload
* add agent_id in SpendLogsPayload
* add _set_agent_id_on_logging_obj
* add agent id tracking in SpendLogs
* add agent id in spend logs
* fix create_a2a_client
* test_asend_message_passes_agent_id_to_callback
* test_get_logging_payload_includes_agent_id_from_kwargs
* test_asend_message_streaming_triggers_callbacks
* fix asend_message_streaming
* asend_message_streaming
* A2AStreamingIterator
* _handle_stream_message
* test_asend_message_streaming_propagates_metadata
1 parent 5d456bc commit 059fedb

File tree

10 files changed

+347
-15
lines changed

10 files changed

+347
-15
lines changed

litellm-proxy-extras/litellm_proxy_extras/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
315315
session_id String?
316316
status String?
317317
mcp_namespaced_tool_name String?
318+
agent_id String?
318319
proxy_server_request Json? @default("{}")
319320
@@index([startTime])
320321
@@index([end_user])

litellm/a2a_protocol/main.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
"""
66

77
import asyncio
8+
import datetime
89
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union
910

1011
import litellm
1112
from litellm._logging import verbose_logger
13+
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
1214
from litellm.a2a_protocol.utils import A2ARequestUtils
15+
from litellm.litellm_core_utils.litellm_logging import Logging
1316
from litellm.llms.custom_httpx.http_handler import (
1417
get_async_httpx_client,
1518
httpxSpecialProvider,
@@ -63,6 +66,26 @@ def _set_usage_on_logging_obj(
6366
litellm_logging_obj.model_call_details["usage"] = usage
6467

6568

69+
def _set_agent_id_on_logging_obj(
70+
kwargs: Dict[str, Any],
71+
agent_id: Optional[str],
72+
) -> None:
73+
"""
74+
Set agent_id on litellm_logging_obj for SpendLogs tracking.
75+
76+
Args:
77+
kwargs: The kwargs dict containing litellm_logging_obj
78+
agent_id: The A2A agent ID
79+
"""
80+
if agent_id is None:
81+
return
82+
83+
litellm_logging_obj = kwargs.get("litellm_logging_obj")
84+
if litellm_logging_obj is not None:
85+
# Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
86+
litellm_logging_obj.model_call_details["agent_id"] = agent_id
87+
88+
6689
def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
6790
"""
6891
Extract agent info and set model/custom_llm_provider for cost tracking.
@@ -101,6 +124,7 @@ async def asend_message(
101124
request: Optional["SendMessageRequest"] = None,
102125
api_base: Optional[str] = None,
103126
litellm_params: Optional[Dict[str, Any]] = None,
127+
agent_id: Optional[str] = None,
104128
**kwargs: Any,
105129
) -> LiteLLMSendMessageResponse:
106130
"""
@@ -114,6 +138,7 @@ async def asend_message(
114138
request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
115139
api_base: API base URL (required for completion bridge, optional for standard A2A)
116140
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
141+
agent_id: Optional agent ID for tracking in SpendLogs
117142
**kwargs: Additional arguments passed to the client decorator
118143
119144
Returns:
@@ -186,11 +211,15 @@ async def asend_message(
186211
return LiteLLMSendMessageResponse.from_dict(response_dict)
187212

188213
# Standard A2A client flow
189-
if a2a_client is None:
190-
raise ValueError("a2a_client is required for standard A2A flow")
191214
if request is None:
192215
raise ValueError("request is required")
193216

217+
# Create A2A client if not provided but api_base is available
218+
if a2a_client is None:
219+
if api_base is None:
220+
raise ValueError("Either a2a_client or api_base is required for standard A2A flow")
221+
a2a_client = await create_a2a_client(base_url=api_base)
222+
194223
agent_name = _get_a2a_model_info(a2a_client, kwargs)
195224

196225
verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")
@@ -216,6 +245,9 @@ async def asend_message(
216245
completion_tokens=completion_tokens,
217246
)
218247

248+
# Set agent_id on logging obj for SpendLogs tracking
249+
_set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)
250+
219251
return response
220252

221253

@@ -254,6 +286,9 @@ async def asend_message_streaming(
254286
request: Optional["SendStreamingMessageRequest"] = None,
255287
api_base: Optional[str] = None,
256288
litellm_params: Optional[Dict[str, Any]] = None,
289+
agent_id: Optional[str] = None,
290+
metadata: Optional[Dict[str, Any]] = None,
291+
proxy_server_request: Optional[Dict[str, Any]] = None,
257292
) -> AsyncIterator[Any]:
258293
"""
259294
Async: Send a streaming message to an A2A agent.
@@ -265,6 +300,9 @@ async def asend_message_streaming(
265300
request: SendStreamingMessageRequest from a2a.types
266301
api_base: API base URL (required for completion bridge)
267302
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
303+
agent_id: Optional agent ID for tracking in SpendLogs
304+
metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
305+
proxy_server_request: Optional proxy server request data
268306
269307
Yields:
270308
SendStreamingMessageResponse chunks from the agent
@@ -320,23 +358,66 @@ async def asend_message_streaming(
320358
return
321359

322360
# Standard A2A client flow
323-
if a2a_client is None:
324-
raise ValueError("a2a_client is required for standard A2A flow")
325361
if request is None:
326362
raise ValueError("request is required")
327363

364+
# Create A2A client if not provided but api_base is available
365+
if a2a_client is None:
366+
if api_base is None:
367+
raise ValueError("Either a2a_client or api_base is required for standard A2A flow")
368+
a2a_client = await create_a2a_client(base_url=api_base)
369+
328370
verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")
329371

372+
# Track for logging
373+
import datetime
374+
375+
start_time = datetime.datetime.now()
330376
stream = a2a_client.send_message_streaming(request)
331377

332-
chunk_count = 0
333-
async for chunk in stream:
334-
chunk_count += 1
335-
yield chunk
378+
# Build logging object for streaming completion callbacks
379+
agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(a2a_client, "agent_card", None)
380+
agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"
381+
model = f"a2a_agent/{agent_name}"
336382

337-
verbose_logger.info(
338-
f"A2A send_message_streaming completed, request_id={request.id}, chunks={chunk_count}"
383+
logging_obj = Logging(
384+
model=model,
385+
messages=[{"role": "user", "content": "streaming-request"}],
386+
stream=False, # complete response logging after stream ends
387+
call_type="asend_message_streaming",
388+
start_time=start_time,
389+
litellm_call_id=str(request.id),
390+
function_id=str(request.id),
339391
)
392+
logging_obj.model = model
393+
logging_obj.custom_llm_provider = "a2a_agent"
394+
logging_obj.model_call_details["model"] = model
395+
logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
396+
if agent_id:
397+
logging_obj.model_call_details["agent_id"] = agent_id
398+
399+
# Propagate litellm_params for spend logging (includes cost_per_query, etc.)
400+
_litellm_params = litellm_params.copy() if litellm_params else {}
401+
# Merge metadata into litellm_params.metadata (required for proxy cost tracking)
402+
if metadata:
403+
_litellm_params["metadata"] = metadata
404+
if proxy_server_request:
405+
_litellm_params["proxy_server_request"] = proxy_server_request
406+
407+
logging_obj.litellm_params = _litellm_params
408+
logging_obj.optional_params = _litellm_params # used by cost calc
409+
logging_obj.model_call_details["litellm_params"] = _litellm_params
410+
logging_obj.model_call_details["metadata"] = metadata or {}
411+
412+
iterator = A2AStreamingIterator(
413+
stream=stream,
414+
request=request,
415+
logging_obj=logging_obj,
416+
agent_name=agent_name,
417+
)
418+
419+
async for chunk in iterator:
420+
yield chunk
340421

341422

342423
async def create_a2a_client(

litellm/a2a_protocol/streaming_iterator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import litellm
1010
from litellm._logging import verbose_logger
11+
from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
1112
from litellm.a2a_protocol.utils import A2ARequestUtils
1213
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
1314
from litellm.litellm_core_utils.thread_pool_executor import executor
@@ -115,11 +116,17 @@ async def _handle_stream_complete(self) -> None:
115116

116117
# Set usage on logging obj
117118
self.logging_obj.model_call_details["usage"] = usage
119+
# Mark stream flag for downstream callbacks
120+
self.logging_obj.model_call_details["stream"] = False
121+
122+
# Calculate cost using A2ACostCalculator
123+
response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
124+
self.logging_obj.model_call_details["response_cost"] = response_cost
118125

119126
# Build result for logging
120127
result = self._build_logging_result(usage)
121128

122-
# Call success handlers
129+
# Call success handlers - they will build standard_logging_object
123130
asyncio.create_task(
124131
self.logging_obj.async_success_handler(
125132
result=result,
@@ -139,7 +146,8 @@ async def _handle_stream_complete(self) -> None:
139146

140147
verbose_logger.info(
141148
f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
142-
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}"
149+
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
150+
f"response_cost={response_cost}"
143151
)
144152

145153
except Exception as e:

litellm/proxy/_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2677,6 +2677,7 @@ class SpendLogsPayload(TypedDict):
26772677
model_id: Optional[str]
26782678
model_group: Optional[str]
26792679
mcp_namespaced_tool_name: Optional[str]
2680+
agent_id: Optional[str]
26802681
api_base: str
26812682
user: str
26822683
metadata: str # json str

litellm/proxy/agent_endpoints/a2a_endpoints.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ async def _handle_stream_message(
5050
request_id: str,
5151
params: dict,
5252
litellm_params: Optional[dict] = None,
53+
agent_id: Optional[str] = None,
54+
metadata: Optional[dict] = None,
55+
proxy_server_request: Optional[dict] = None,
5356
) -> StreamingResponse:
5457
"""Handle message/stream method via SDK functions."""
5558
from a2a.types import MessageSendParams, SendStreamingMessageRequest
@@ -66,6 +69,9 @@ async def stream_response():
6669
request=a2a_request,
6770
api_base=api_base,
6871
litellm_params=litellm_params,
72+
agent_id=agent_id,
73+
metadata=metadata,
74+
proxy_server_request=proxy_server_request,
6975
):
7076
# Chunk may be dict or object depending on bridge vs standard path
7177
if hasattr(chunk, "model_dump"):
@@ -241,6 +247,7 @@ async def invoke_agent_a2a(
241247
request=a2a_request,
242248
api_base=agent_url,
243249
litellm_params=litellm_params,
250+
agent_id=agent.agent_id,
244251
metadata=data.get("metadata", {}),
245252
proxy_server_request=data.get("proxy_server_request"),
246253
)
@@ -252,6 +259,9 @@ async def invoke_agent_a2a(
252259
request_id=request_id,
253260
params=params,
254261
litellm_params=litellm_params,
262+
agent_id=agent.agent_id,
263+
metadata=data.get("metadata", {}),
264+
proxy_server_request=data.get("proxy_server_request"),
255265
)
256266
else:
257267
return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")

litellm/proxy/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
315315
session_id String?
316316
status String?
317317
mcp_namespaced_tool_name String?
318+
agent_id String?
318319
proxy_server_request Json? @default("{}")
319320
@@index([startTime])
320321
@@index([end_user])

litellm/proxy/spend_tracking/spend_tracking_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,16 @@ def get_logging_payload( # noqa: PLR0915
225225
response_obj_dict = {}
226226

227227
# Handle OCR responses which use usage_info instead of usage
228+
usage: dict = {}
228229
if call_type in ["ocr", "aocr"]:
229230
usage = _extract_usage_for_ocr_call(response_obj, response_obj_dict)
230231
else:
231232
# Use response_obj_dict instead of response_obj to avoid calling .get() on Pydantic models
232-
usage = response_obj_dict.get("usage", None) or {}
233-
if isinstance(usage, litellm.Usage):
234-
usage = dict(usage)
233+
_usage = response_obj_dict.get("usage", None) or {}
234+
if isinstance(_usage, litellm.Usage):
235+
usage = dict(_usage)
236+
elif isinstance(_usage, dict):
237+
usage = _usage
235238

236239
id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
237240
standard_logging_payload = cast(
@@ -369,6 +372,9 @@ def get_logging_payload( # noqa: PLR0915
369372
"namespaced_tool_name", None
370373
)
371374

375+
# Extract agent_id for A2A requests (set directly on model_call_details)
376+
agent_id: Optional[str] = kwargs.get("agent_id")
377+
372378
try:
373379
payload: SpendLogsPayload = SpendLogsPayload(
374380
request_id=str(id),
@@ -396,6 +402,7 @@ def get_logging_payload( # noqa: PLR0915
396402
model_group=_model_group,
397403
model_id=_model_id,
398404
mcp_namespaced_tool_name=mcp_namespaced_tool_name,
405+
agent_id=agent_id,
399406
requester_ip_address=clean_metadata.get("requester_ip_address", None),
400407
custom_llm_provider=kwargs.get("custom_llm_provider", ""),
401408
messages=_get_messages_for_spend_logs_payload(

schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
315315
session_id String?
316316
status String?
317317
mcp_namespaced_tool_name String?
318+
agent_id String?
318319
proxy_server_request Json? @default("{}")
319320
@@index([startTime])
320321
@@index([end_user])

0 commit comments

Comments (0)