Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test_asend_message_streaming_triggers_callbacks
  • Loading branch information
ishaan-jaff committed Dec 10, 2025
commit 1c4358bb1fc0a3adc41cb33a03b85b16968d6d0b
52 changes: 52 additions & 0 deletions tests/test_litellm/a2a_protocol/test_cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,55 @@ async def test_asend_message_passes_agent_id_to_callback():

# Verify agent_id was passed to callback
assert agent_id_logger.agent_id == test_agent_id, f"Expected agent_id '{test_agent_id}', got '{agent_id_logger.agent_id}'"


@pytest.mark.asyncio
async def test_asend_message_streaming_triggers_callbacks():
"""
Test that asend_message_streaming triggers callbacks after stream completes.
"""
from litellm.a2a_protocol import asend_message_streaming

# Setup logger - must use logging_callback_manager to properly register
litellm.logging_callback_manager._reset_all_callbacks()
callback_logger = AgentIdLogger()
litellm.logging_callback_manager.add_litellm_async_success_callback(callback_logger)
litellm.logging_callback_manager.add_litellm_success_callback(callback_logger)

# Mock A2A client
mock_client = MagicMock()
mock_client._litellm_agent_card = MagicMock()
mock_client._litellm_agent_card.name = "test-agent"

# Mock streaming response
async def mock_stream():
yield MagicMock(model_dump=lambda mode, exclude_none: {"chunk": 1})
yield MagicMock(model_dump=lambda mode, exclude_none: {"chunk": 2})

mock_client.send_message_streaming = MagicMock(return_value=mock_stream())

# Mock request
mock_request = MagicMock()
mock_request.id = "test-stream-123"
mock_request.params = MagicMock()
mock_request.params.message = {"role": "user", "parts": [{"kind": "text", "text": "Hello"}]}

test_agent_id = "test-agent-id-streaming"

# Consume streaming response
chunks = []
async for chunk in asend_message_streaming(
a2a_client=mock_client,
request=mock_request,
agent_id=test_agent_id,
):
chunks.append(chunk)

await asyncio.sleep(0.2)

# Verify chunks were received
assert len(chunks) == 2

# Verify callbacks WERE triggered after stream completed
assert callback_logger.kwargs is not None, "Streaming should trigger callbacks after completion"
assert callback_logger.agent_id == test_agent_id, f"Expected agent_id '{test_agent_id}', got '{callback_logger.agent_id}'"