Moved the message conversion from str to list[Message] so that it happens in call() only
maykcaldas committed Jul 19, 2025
commit e699df49f875cb6570ed55a8f6ed9bceb0eb0702
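
For orientation, here is a minimal sketch of the pattern this commit moves toward: the str → list[Message] coercion happens only in call(), and call_single() simply delegates to it. The class and helper names below (_Sketch, _run_completions) are placeholders for illustration, not the real LiteLLMModel internals.

from aviary.core import Message


class _Sketch:
    """Illustration only -- not the actual LiteLLMModel implementation."""

    async def call(self, messages: list[Message] | str, **kwargs) -> list:
        if isinstance(messages, str):
            # convenience for a single message; this branch previously lived in call_single
            messages = [Message(content=messages)]
        # everything downstream can now assume list[Message]
        return await self._run_completions(messages, **kwargs)

    async def _run_completions(self, messages: list[Message], **kwargs) -> list:
        return []  # placeholder standing in for the real completion logic

    async def call_single(self, messages: list[Message] | str, **kwargs):
        # no str handling here anymore: call() performs the conversion exactly once
        results = await self.call(messages, **{**kwargs, "n": 1})
        return results[0] if results else None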
42 changes: 18 additions & 24 deletions packages/lmi/src/lmi/llms.py
@@ -27,7 +27,7 @@
)
from enum import StrEnum
from inspect import isasyncgenfunction, isawaitable, signature
from typing import Any, ClassVar, ParamSpec, TypeAlias, cast, overload
from typing import Any, ClassVar, Literal, ParamSpec, TypeAlias, cast, overload

import litellm
from aviary.core import (
@@ -235,12 +235,12 @@ async def call(
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = False,
stream: Literal[False] = ...,
**kwargs,
) -> list[LLMResult]: ...

@overload
async def call( # type: ignore[overload-cannot-match]
async def call(
self,
messages: list[Message] | str,
callbacks: (
@@ -280,8 +280,8 @@ async def call( # noqa: C901, PLR0915
kwargs: Additional keyword arguments for the chat completion.

Returns:
A list of LLMResult objects containing the result of the call when stream=False,
or an AsyncGenerator[LLMResult] when stream=True.
When not streaming, it's a list of result objects for each call,
otherwise it's an async generator of result objects.

Raises:
ValueError: If the LLM type is unknown.
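
A usage sketch of the two return shapes described above (the import path, model name, and config here are assumptions for illustration; the streaming behavior mirrors test_streaming further down):

import asyncio

from aviary.core import Message
from lmi.llms import LiteLLMModel  # import path assumed for this sketch


async def main() -> None:
    llm = LiteLLMModel(name="gpt-4o-mini", config={"n": 1})  # illustrative model name and config
    messages = [Message(content="Say hello.")]

    # stream defaults to False: a list[LLMResult] comes back, one entry per completion
    results = await llm.call(messages)
    print(results[0].text)

    # stream=True: an AsyncGenerator[LLMResult] comes back, yielding results as the stream progresses
    chunks = await llm.call(messages, stream=True)
    async for chunk in chunks:
        print(chunk.text or "", end="")


asyncio.run(main())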
@@ -294,7 +294,6 @@ async def call( # noqa: C901, PLR0915
# there may be a nested 'config' key
# that can't be used by chat
chat_kwargs.pop("config", None)
chat_kwargs.pop("stream", None)
n = chat_kwargs.get("n") or self.config.get("n", 1)
if n < 1:
raise ValueError("Number of completions (n) must be >= 1.")
@@ -426,12 +425,12 @@ async def call_single(
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = False,
stream: Literal[False] = ...,
**kwargs,
) -> LLMResult: ...

@overload
async def call_single( # type: ignore[overload-cannot-match]
async def call_single(
self,
messages: list[Message] | str,
callbacks: (
@@ -441,7 +440,7 @@ async def call_single( # type: ignore[overload-cannot-match]
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = True,
stream: Literal[True] = ...,
**kwargs,
) -> AsyncGenerator[LLMResult]: ...
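
The overload change above is what allows the # type: ignore[overload-cannot-match] comments to be dropped: with stream typed as Literal[False] / Literal[True] instead of a plain bool default, a type checker can pick the matching overload from the call site and narrow the return type. A self-contained sketch of that pattern (function name and types here are illustrative, not from lmi):

from collections.abc import AsyncGenerator
from typing import Literal, overload


@overload
async def fetch(stream: Literal[False] = ...) -> list[str]: ...
@overload
async def fetch(stream: Literal[True] = ...) -> AsyncGenerator[str]: ...
async def fetch(stream: bool = False) -> list[str] | AsyncGenerator[str]:
    # the implementation still takes a plain bool; the overloads above drive narrowing
    async def _gen() -> AsyncGenerator[str]:
        yield "chunk"

    return _gen() if stream else ["full response"]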

@@ -458,9 +457,6 @@ async def call_single(
stream: bool = False,
**kwargs,
) -> LLMResult | AsyncGenerator[LLMResult]:
if isinstance(messages, str):
# convenience for single message
messages = [Message(content=messages)]
kwargs = {**kwargs, "n": 1}
results = await self.call(
messages,
@@ -819,15 +815,14 @@ async def acompletion_iter(
role = None
reasoning_content = []
used_model = None
first_token_time = None

async for completion in stream_completions:
if not used_model:
used_model = completion.model or self.name
choice = completion.choices[0]
delta = choice.delta

if first_token_time is None and delta.content:
if delta.content:
seconds_to_first_token = asyncio.get_running_loop().time() - start_clock

if logprob_content := getattr(choice.logprobs, "content", None):
@@ -837,16 +832,15 @@
role = delta.role or role
if hasattr(delta, "reasoning_content"):
reasoning_content.append(delta.reasoning_content or "")

yield LLMResult(
model=used_model,
text=delta.content,
prompt=messages,
messages=[Message(role=role, content=accumulated_text)],
logprob=sum_logprobs(logprobs),
reasoning_content="".join(reasoning_content),
seconds_to_first_token=seconds_to_first_token,
)
yield LLMResult(
model=used_model,
text=accumulated_text,
prompt=messages,
messages=[Message(role=role, content=accumulated_text)],
logprob=sum_logprobs(logprobs),
reasoning_content="".join(reasoning_content),
seconds_to_first_token=seconds_to_first_token,
)

def count_tokens(self, text: str) -> int:
return litellm.token_counter(model=self.name, text=text)
2 changes: 1 addition & 1 deletion packages/lmi/tests/test_cost_tracking.py
@@ -107,7 +107,7 @@ async def ac(x) -> None:
with assert_costs_increased():
await llm.call(messages, [ac])

@pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"])
# @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"])
@pytest.mark.parametrize(
"config",
[
73 changes: 51 additions & 22 deletions packages/lmi/tests/test_llms.py
@@ -1,6 +1,6 @@
import pathlib
import pickle
from collections.abc import AsyncIterable, AsyncIterator
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from typing import Any, ClassVar
from unittest.mock import Mock, patch

@@ -323,11 +323,22 @@ async def ac(x) -> None:

with subtests.test(msg="autowraps message"):

def mock_call(messages, *_, **__):
def mock_acompletion(messages, *_, **__):
assert isinstance(messages, list)
return [None]
assert len(messages) == 1
assert isinstance(messages[0], Message)
assert messages[0].content == "Test message"
return [
LLMResult(
model="test-model",
messages=messages,
text="Test result",
)
]

with patch.object(LiteLLMModel, "call", side_effect=mock_call):
with patch.object(
LiteLLMModel, "acompletion", side_effect=mock_acompletion
):
await llm.call_single("Test message")

@pytest.mark.vcr
@@ -451,12 +462,7 @@ def _build_mock_completion(

# Mock completion with valid logprobs
mock_completion_valid = _build_mock_completion(
logprobs=Mock(content=[Mock(logprob=-0.5)])
)

# Mock completion with usage info
mock_completion_usage = _build_mock_completion(
usage=Mock(prompt_tokens=10, completion_tokens=5)
logprobs=Mock(content=[Mock(logprob=-0.5)], delta_content="")
)

# Create async generator that yields mock completions
@@ -466,7 +472,6 @@ async def mock_stream_iter(): # noqa: RUF029
yield mock_completion_no_content
yield mock_completion_empty
yield mock_completion_valid
yield mock_completion_usage

return mock_stream_iter()

@@ -477,14 +482,12 @@ async def mock_stream_iter(): # noqa: RUF029
results = [result async for result in async_iterable]

# Verify we got one final result
assert len(results) == 1
result = results[0]
assert len(results) > 1
result = results[-1]
assert isinstance(result, LLMResult)
assert result.text == "Hello world!"
assert result.model == "test-model"
assert result.logprob == -0.5
assert result.prompt_count == 10
assert result.completion_count == 5


class DummyOutputSchema(BaseModel):
@@ -565,10 +568,13 @@ async def test_model(self, model_name: str) -> None:

@pytest.mark.parametrize(
"model_name",
[CommonLLMNames.ANTHROPIC_TEST.value, CommonLLMNames.GPT_35_TURBO.value],
[
pytest.param(CommonLLMNames.ANTHROPIC_TEST.value, id="anthropic"),
pytest.param(CommonLLMNames.GPT_35_TURBO.value, id="openai"),
],
)
@pytest.mark.asyncio
async def test_streaming(self, model_name: str) -> None:
async def test_streaming(self, model_name: str, subtests) -> None:
model = self.MODEL_CLS(name=model_name, config=self.DEFAULT_CONFIG)
messages = [
Message(role="system", content="Respond with single words."),
@@ -578,11 +584,34 @@ async def test_streaming(self, model_name: str) -> None:
def callback(_) -> None:
return

with pytest.raises(
NotImplementedError,
match="Multiple completions with callbacks is not supported",
with (
subtests.test(name="n=2"),
pytest.raises(
ValueError,
match=r"\(n\) must be 1",
),
):
await self.call_model(model, messages, [callback])
await self.call_model(model, messages, stream=True, callbacks=[callback])

config = self.DEFAULT_CONFIG.copy()
config["n"] = 1
model = self.MODEL_CLS(name=model_name, config=config)
with (
subtests.test(name="with callbacks"),
pytest.raises(
NotImplementedError,
match="callbacks with streaming is not supported",
),
):
await self.call_model(model, messages, stream=True, callbacks=[callback])

with subtests.test(name="n=1"):
results = await self.call_model(model, messages, stream=True)
assert isinstance(results, AsyncGenerator)
async for result in results:
assert isinstance(result, LLMResult)
assert result.messages
assert len(result.messages) == 1

@pytest.mark.vcr
@pytest.mark.asyncio
@@ -688,7 +717,7 @@ async def test_text_image_message(self, model_name: str) -> None:
)
@pytest.mark.asyncio
@pytest.mark.vcr
async def test_single_completion(self, model_name: str) -> None:
async def test_single_completion(self, model_name: str, subtests) -> None:
model = self.MODEL_CLS(name=model_name, config={"n": 1})
messages = [
Message(role="system", content="Respond with single words."),