1 change: 1 addition & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
session_id String?
status String?
mcp_namespaced_tool_name String?
+  agent_id String?
proxy_server_request Json? @default("{}")
@@index([startTime])
@@index([end_user])
101 changes: 91 additions & 10 deletions litellm/a2a_protocol/main.py
@@ -5,11 +5,14 @@
"""

import asyncio
+import datetime
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union

import litellm
from litellm._logging import verbose_logger
+from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
from litellm.a2a_protocol.utils import A2ARequestUtils
+from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
@@ -63,6 +66,26 @@ def _set_usage_on_logging_obj(
litellm_logging_obj.model_call_details["usage"] = usage


+def _set_agent_id_on_logging_obj(
+    kwargs: Dict[str, Any],
+    agent_id: Optional[str],
+) -> None:
+    """
+    Set agent_id on litellm_logging_obj for SpendLogs tracking.
+
+    Args:
+        kwargs: The kwargs dict containing litellm_logging_obj
+        agent_id: The A2A agent ID
+    """
+    if agent_id is None:
+        return
+
+    litellm_logging_obj = kwargs.get("litellm_logging_obj")
+    if litellm_logging_obj is not None:
+        # Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
+        litellm_logging_obj.model_call_details["agent_id"] = agent_id
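A quick illustration of the helper's behavior (a minimal sketch; SimpleNamespace stands in for the real Logging object that litellm's client decorator normally places under kwargs["litellm_logging_obj"]):

from types import SimpleNamespace

# Stand-in for the real Logging object; only model_call_details matters here.
logging_obj = SimpleNamespace(model_call_details={})
kwargs = {"litellm_logging_obj": logging_obj}

_set_agent_id_on_logging_obj(kwargs=kwargs, agent_id="agent-abc123")
assert logging_obj.model_call_details["agent_id"] == "agent-abc123"

# Both no-op paths are safe: no agent_id, and no logging object in kwargs.
_set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=None)
_set_agent_id_on_logging_obj(kwargs={}, agent_id="agent-abc123")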


def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
"""
Extract agent info and set model/custom_llm_provider for cost tracking.
@@ -101,6 +124,7 @@ async def asend_message(
request: Optional["SendMessageRequest"] = None,
api_base: Optional[str] = None,
litellm_params: Optional[Dict[str, Any]] = None,
+    agent_id: Optional[str] = None,
**kwargs: Any,
) -> LiteLLMSendMessageResponse:
"""
@@ -114,6 +138,7 @@
request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
api_base: API base URL (required for completion bridge, optional for standard A2A)
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
+        agent_id: Optional agent ID for tracking in SpendLogs
**kwargs: Additional arguments passed to the client decorator

Returns:
@@ -186,11 +211,15 @@ async def asend_message(
return LiteLLMSendMessageResponse.from_dict(response_dict)

    # Standard A2A client flow
-    if a2a_client is None:
-        raise ValueError("a2a_client is required for standard A2A flow")
    if request is None:
        raise ValueError("request is required")

+    # Create A2A client if not provided but api_base is available
+    if a2a_client is None:
+        if api_base is None:
+            raise ValueError("Either a2a_client or api_base is required for standard A2A flow")
+        a2a_client = await create_a2a_client(base_url=api_base)

agent_name = _get_a2a_model_info(a2a_client, kwargs)

verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")
@@ -216,6 +245,9 @@
completion_tokens=completion_tokens,
)

+    # Set agent_id on logging obj for SpendLogs tracking
+    _set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)

return response
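For orientation, a hedged sketch of the new call shape: with no pre-built a2a_client, asend_message builds one from api_base, and agent_id flows onto the logging object and into SpendLogs. The endpoint URL and IDs are hypothetical, and the request construction follows a2a.types as of recent SDK versions (adjust field names if your version differs):

import asyncio

from a2a.types import MessageSendParams, SendMessageRequest

from litellm.a2a_protocol.main import asend_message

params = {
    "message": {
        "role": "user",
        "parts": [{"kind": "text", "text": "hello agent"}],
        "messageId": "msg-1",
    }
}

async def main() -> None:
    response = await asend_message(
        request=SendMessageRequest(id="req-1", params=MessageSendParams(**params)),
        api_base="http://localhost:9999",  # hypothetical agent endpoint
        agent_id="agent-abc123",  # recorded in SpendLogs.agent_id
    )
    print(response)

asyncio.run(main())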


@@ -254,6 +286,9 @@ async def asend_message_streaming(
request: Optional["SendStreamingMessageRequest"] = None,
api_base: Optional[str] = None,
litellm_params: Optional[Dict[str, Any]] = None,
+    agent_id: Optional[str] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+    proxy_server_request: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[Any]:
"""
Async: Send a streaming message to an A2A agent.
@@ -265,6 +300,9 @@
request: SendStreamingMessageRequest from a2a.types
api_base: API base URL (required for completion bridge)
litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
+        agent_id: Optional agent ID for tracking in SpendLogs
+        metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
+        proxy_server_request: Optional proxy server request data

Yields:
SendStreamingMessageResponse chunks from the agent
@@ -320,23 +358,66 @@
return

    # Standard A2A client flow
-    if a2a_client is None:
-        raise ValueError("a2a_client is required for standard A2A flow")
    if request is None:
        raise ValueError("request is required")

+    # Create A2A client if not provided but api_base is available
+    if a2a_client is None:
+        if api_base is None:
+            raise ValueError("Either a2a_client or api_base is required for standard A2A flow")
+        a2a_client = await create_a2a_client(base_url=api_base)

verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")

    # Track for logging
-    import datetime
-
    start_time = datetime.datetime.now()
    stream = a2a_client.send_message_streaming(request)

-    chunk_count = 0
-    async for chunk in stream:
-        chunk_count += 1
-        yield chunk
-
-    verbose_logger.info(
-        f"A2A send_message_streaming completed, request_id={request.id}, chunks={chunk_count}"
-    )
+    # Build logging object for streaming completion callbacks
+    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(a2a_client, "agent_card", None)
+    agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"
+    model = f"a2a_agent/{agent_name}"
+
+    logging_obj = Logging(
+        model=model,
+        messages=[{"role": "user", "content": "streaming-request"}],
+        stream=False,  # complete response logging after stream ends
+        call_type="asend_message_streaming",
+        start_time=start_time,
+        litellm_call_id=str(request.id),
+        function_id=str(request.id),
+    )
+    logging_obj.model = model
+    logging_obj.custom_llm_provider = "a2a_agent"
+    logging_obj.model_call_details["model"] = model
+    logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
+    if agent_id:
+        logging_obj.model_call_details["agent_id"] = agent_id
+
+    # Propagate litellm_params for spend logging (includes cost_per_query, etc.)
+    _litellm_params = litellm_params.copy() if litellm_params else {}
+    # Merge metadata into litellm_params.metadata (required for proxy cost tracking)
+    if metadata:
+        _litellm_params["metadata"] = metadata
+    if proxy_server_request:
+        _litellm_params["proxy_server_request"] = proxy_server_request
+
+    logging_obj.litellm_params = _litellm_params
+    logging_obj.optional_params = _litellm_params  # used by cost calc
+    logging_obj.model_call_details["litellm_params"] = _litellm_params
+    logging_obj.model_call_details["metadata"] = metadata or {}
+
+    iterator = A2AStreamingIterator(
+        stream=stream,
+        request=request,
+        logging_obj=logging_obj,
+        agent_name=agent_name,
+    )
+
+    async for chunk in iterator:
+        yield chunk


async def create_a2a_client(
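The streaming counterpart, sketched under the same assumptions; metadata and proxy_server_request are merged into litellm_params above so proxy cost tracking sees them:

import asyncio

from a2a.types import MessageSendParams, SendStreamingMessageRequest

from litellm.a2a_protocol.main import asend_message_streaming

async def main() -> None:
    request = SendStreamingMessageRequest(
        id="req-2",
        params=MessageSendParams(
            **{
                "message": {
                    "role": "user",
                    "parts": [{"kind": "text", "text": "hello agent"}],
                    "messageId": "msg-2",
                }
            }
        ),
    )
    async for chunk in asend_message_streaming(
        request=request,
        api_base="http://localhost:9999",  # hypothetical agent endpoint
        agent_id="agent-abc123",
        metadata={"user_api_key": "hashed-key"},  # illustrative proxy metadata
        proxy_server_request={"url": "/a2a/my-agent"},  # illustrative
    ):
        print(chunk)

asyncio.run(main())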
12 changes: 10 additions & 2 deletions litellm/a2a_protocol/streaming_iterator.py
@@ -8,6 +8,7 @@

import litellm
from litellm._logging import verbose_logger
+from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
from litellm.a2a_protocol.utils import A2ARequestUtils
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.thread_pool_executor import executor
@@ -115,11 +116,17 @@ async def _handle_stream_complete(self) -> None:

# Set usage on logging obj
self.logging_obj.model_call_details["usage"] = usage
+        # Mark stream flag for downstream callbacks
+        self.logging_obj.model_call_details["stream"] = False
+
+        # Calculate cost using A2ACostCalculator
+        response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
+        self.logging_obj.model_call_details["response_cost"] = response_cost

# Build result for logging
result = self._build_logging_result(usage)

-        # Call success handlers
+        # Call success handlers - they will build standard_logging_object
asyncio.create_task(
self.logging_obj.async_success_handler(
result=result,
@@ -139,7 +146,8 @@ async def _handle_stream_complete(self) -> None:

verbose_logger.info(
f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}"
f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
f"response_cost={response_cost}"
)

except Exception as e:
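The completion path above follows litellm's usual dual-dispatch pattern for callbacks; a simplified sketch (the async side appears in this hunk, the thread-pool side is the library's standard pattern and abbreviated here):

import asyncio

from litellm.litellm_core_utils.thread_pool_executor import executor

def fire_success_handlers(logging_obj, result, start_time, end_time):
    # Async callbacks are scheduled on the running loop so the stream
    # consumer is never blocked on logging.
    asyncio.create_task(
        logging_obj.async_success_handler(
            result=result, start_time=start_time, end_time=end_time
        )
    )
    # Sync callbacks go to the shared thread pool for the same reason.
    executor.submit(logging_obj.success_handler, result, start_time, end_time)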
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -2677,6 +2677,7 @@ class SpendLogsPayload(TypedDict):
model_id: Optional[str]
model_group: Optional[str]
mcp_namespaced_tool_name: Optional[str]
+    agent_id: Optional[str]
api_base: str
user: str
metadata: str # json str
10 changes: 10 additions & 0 deletions litellm/proxy/agent_endpoints/a2a_endpoints.py
@@ -50,6 +50,9 @@ async def _handle_stream_message(
request_id: str,
params: dict,
litellm_params: Optional[dict] = None,
+    agent_id: Optional[str] = None,
+    metadata: Optional[dict] = None,
+    proxy_server_request: Optional[dict] = None,
) -> StreamingResponse:
"""Handle message/stream method via SDK functions."""
from a2a.types import MessageSendParams, SendStreamingMessageRequest
@@ -66,6 +69,9 @@ async def stream_response():
request=a2a_request,
api_base=api_base,
litellm_params=litellm_params,
+            agent_id=agent_id,
+            metadata=metadata,
+            proxy_server_request=proxy_server_request,
):
# Chunk may be dict or object depending on bridge vs standard path
if hasattr(chunk, "model_dump"):
@@ -241,6 +247,7 @@ async def invoke_agent_a2a(
request=a2a_request,
api_base=agent_url,
litellm_params=litellm_params,
+        agent_id=agent.agent_id,
metadata=data.get("metadata", {}),
proxy_server_request=data.get("proxy_server_request"),
)
@@ -252,6 +259,9 @@
request_id=request_id,
params=params,
litellm_params=litellm_params,
+        agent_id=agent.agent_id,
+        metadata=data.get("metadata", {}),
+        proxy_server_request=data.get("proxy_server_request"),
)
else:
return _jsonrpc_error(request_id, -32601, f"Method '{method}' not found")
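Reduced to its core, the streaming bridge in _handle_stream_message looks roughly like this (the model_dump branch mirrors the hunk above; the SSE framing, exclude_none, and media type are assumptions):

import json

from fastapi.responses import StreamingResponse

def to_streaming_response(chunks) -> StreamingResponse:
    async def stream_response():
        async for chunk in chunks:
            # Standard path yields pydantic objects; the completion bridge yields dicts.
            if hasattr(chunk, "model_dump"):
                payload = chunk.model_dump(exclude_none=True)
            else:
                payload = chunk
            yield f"data: {json.dumps(payload, default=str)}\n\n"

    return StreamingResponse(stream_response(), media_type="text/event-stream")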
1 change: 1 addition & 0 deletions litellm/proxy/schema.prisma
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
session_id String?
status String?
mcp_namespaced_tool_name String?
+  agent_id String?
proxy_server_request Json? @default("{}")
@@index([startTime])
@@index([end_user])
13 changes: 10 additions & 3 deletions litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -225,13 +225,16 @@ def get_logging_payload( # noqa: PLR0915
response_obj_dict = {}

# Handle OCR responses which use usage_info instead of usage
+    usage: dict = {}
if call_type in ["ocr", "aocr"]:
usage = _extract_usage_for_ocr_call(response_obj, response_obj_dict)
else:
# Use response_obj_dict instead of response_obj to avoid calling .get() on Pydantic models
-        usage = response_obj_dict.get("usage", None) or {}
-        if isinstance(usage, litellm.Usage):
-            usage = dict(usage)
+        _usage = response_obj_dict.get("usage", None) or {}
+        if isinstance(_usage, litellm.Usage):
+            usage = dict(_usage)
+        elif isinstance(_usage, dict):
+            usage = _usage

id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
standard_logging_payload = cast(
@@ -369,6 +372,9 @@
"namespaced_tool_name", None
)

+    # Extract agent_id for A2A requests (set directly on model_call_details)
+    agent_id: Optional[str] = kwargs.get("agent_id")

try:
payload: SpendLogsPayload = SpendLogsPayload(
request_id=str(id),
@@ -396,6 +402,7 @@
model_group=_model_group,
model_id=_model_id,
mcp_namespaced_tool_name=mcp_namespaced_tool_name,
+        agent_id=agent_id,
requester_ip_address=clean_metadata.get("requester_ip_address", None),
custom_llm_provider=kwargs.get("custom_llm_provider", ""),
messages=_get_messages_for_spend_logs_payload(
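The normalization above exists because response_obj_dict.get("usage") may return a pydantic litellm.Usage rather than a dict, and calling dict methods on it would fail downstream. Illustratively:

import litellm

_usage = litellm.Usage(prompt_tokens=3, completion_tokens=5, total_tokens=8)

usage: dict = {}
if isinstance(_usage, litellm.Usage):
    usage = dict(_usage)  # pydantic models iterate as (key, value) pairs
elif isinstance(_usage, dict):
    usage = _usage

print(usage["total_tokens"])  # 8; a plain dict either way, safe for .get() downstream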
1 change: 1 addition & 0 deletions schema.prisma
@@ -315,6 +315,7 @@ model LiteLLM_SpendLogs {
session_id String?
status String?
mcp_namespaced_tool_name String?
+  agent_id String?
proxy_server_request Json? @default("{}")
@@index([startTime])
@@index([end_user])