diff --git a/packages/lmi/src/lmi/llms.py b/packages/lmi/src/lmi/llms.py index 97237014..cab58838 100644 --- a/packages/lmi/src/lmi/llms.py +++ b/packages/lmi/src/lmi/llms.py @@ -17,7 +17,7 @@ import logging from abc import ABC from collections.abc import ( - AsyncIterable, + AsyncGenerator, Awaitable, Callable, Coroutine, @@ -27,7 +27,7 @@ ) from enum import StrEnum from inspect import isasyncgenfunction, isawaitable, signature -from typing import Any, ClassVar, ParamSpec, TypeAlias, cast, overload +from typing import Any, ClassVar, Literal, ParamSpec, TypeAlias, cast, overload import litellm from aviary.core import ( @@ -201,8 +201,8 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult async def acompletion_iter( self, messages: list[Message], **kwargs - ) -> AsyncIterable[LLMResult]: - """Return an async generator that yields completions. + ) -> AsyncGenerator[LLMResult]: + """Return an async generator that `yield`s completions. Only the last tuple will be non-zero. """ @@ -224,9 +224,39 @@ def __str__(self) -> str: # None means we won't provide a tool_choice to the LLM API UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None + @overload + async def call( + self, + messages: list[Message] | str, + callbacks: ( + Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None + ) = ..., + name: str | None = ..., + output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., + tools: list[Tool] | None = ..., + tool_choice: Tool | str | None = ..., + stream: Literal[False] = ..., + **kwargs, + ) -> list[LLMResult]: ... + + @overload + async def call( + self, + messages: list[Message] | str, + callbacks: ( + Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None + ) = ..., + name: str | None = ..., + output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., + tools: list[Tool] | None = ..., + tool_choice: Tool | str | None = ..., + stream: bool = True, + **kwargs, + ) -> AsyncGenerator[LLMResult]: ... + async def call( # noqa: C901, PLR0915 self, - messages: list[Message], + messages: list[Message] | str, callbacks: ( Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None ) = None, @@ -234,8 +264,9 @@ async def call( # noqa: C901, PLR0915 output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None, tools: list[Tool] | None = None, tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, + stream: bool = False, **kwargs, - ) -> list[LLMResult]: + ) -> list[LLMResult] | AsyncGenerator[LLMResult]: """Call the LLM model with the given messages and configuration. Args: @@ -245,14 +276,19 @@ async def call( # noqa: C901, PLR0915 output_type: The type of the output. tools: A list of tools to use. tool_choice: The tool choice to use. + stream: Whether to stream the response or return all results at once. kwargs: Additional keyword arguments for the chat completion. Returns: - A list of LLMResult objects containing the result of the call. + When not streaming, it's a list of result objects for each call, + otherwise it's an async generator of result objects. Raises: ValueError: If the LLM type is unknown. 
""" + if isinstance(messages, str): + # convenience for single message + messages = [Message(content=messages)] chat_kwargs = copy.deepcopy(kwargs) # if using the config for an LLMModel, # there may be a nested 'config' key @@ -261,6 +297,8 @@ async def call( # noqa: C901, PLR0915 n = chat_kwargs.get("n") or self.config.get("n", 1) if n < 1: raise ValueError("Number of completions (n) must be >= 1.") + if stream and n > 1: + raise ValueError("Number of completions (n) must be 1 when streaming.") if "fallbacks" not in chat_kwargs and "fallbacks" in self.config: chat_kwargs["fallbacks"] = self.config.get("fallbacks", []) @@ -328,48 +366,83 @@ async def call( # noqa: C901, PLR0915 ) for m in messages ] - results: list[LLMResult] = [] start_clock = asyncio.get_running_loop().time() - if callbacks is None: + + # If not streaming, simply return the results + if not stream: + sync_callbacks = [ + f for f in (callbacks or []) if not is_coroutine_callable(f) + ] + async_callbacks = [f for f in (callbacks or []) if is_coroutine_callable(f)] results = await self.acompletion(messages, **chat_kwargs) - else: - if tools: - raise NotImplementedError("Using tools with callbacks is not supported") - n = chat_kwargs.get("n") or self.config.get("n", 1) - if n > 1: - raise NotImplementedError( - "Multiple completions with callbacks is not supported" + for result in results: + text = cast("str", result.text) + await do_callbacks(async_callbacks, sync_callbacks, text, name) + usage = result.prompt_count, result.completion_count + if not sum(usage): + result.completion_count = self.count_tokens(text) + result.seconds_to_last_token = ( + asyncio.get_running_loop().time() - start_clock ) - sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] - async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] - stream_results = await self.acompletion_iter(messages, **chat_kwargs) - text_result = [] - async for result in stream_results: - if result.text: - if result.seconds_to_first_token == 0: - result.seconds_to_first_token = ( - asyncio.get_running_loop().time() - start_clock - ) - text_result.append(result.text) - await do_callbacks( - async_callbacks, sync_callbacks, result.text, name + result.name = name + if self.llm_result_callback: + possibly_awaitable_result = self.llm_result_callback(result) + if isawaitable(possibly_awaitable_result): + await possibly_awaitable_result + return results + + # If streaming, return an AsyncGenerator[LLMResult] + if tools: + raise NotImplementedError("Using tools with streaming is not supported") + if callbacks: + raise NotImplementedError("Using callbacks with streaming is not supported") + + async def process_stream() -> AsyncGenerator[LLMResult]: + async_iterable = await self.acompletion_iter(messages, **chat_kwargs) + async for result in async_iterable: + usage = result.prompt_count, result.completion_count + if not sum(usage): + result.completion_count = self.count_tokens( + cast("str", result.text) ) - results.append(result) - - for result in results: - usage = result.prompt_count, result.completion_count - if not sum(usage): - result.completion_count = self.count_tokens(cast("str", result.text)) - result.seconds_to_last_token = ( - asyncio.get_running_loop().time() - start_clock - ) - result.name = name - if self.llm_result_callback: - possibly_awaitable_result = self.llm_result_callback(result) - if isawaitable(possibly_awaitable_result): - await possibly_awaitable_result - return results + result.seconds_to_last_token = ( + 
asyncio.get_running_loop().time() - start_clock + ) + result.name = name + yield result + + return process_stream() + + @overload + async def call_single( + self, + messages: list[Message] | str, + callbacks: ( + Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None + ) = ..., + name: str | None = ..., + output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., + tools: list[Tool] | None = ..., + tool_choice: Tool | str | None = ..., + stream: Literal[False] = ..., + **kwargs, + ) -> LLMResult: ... + + @overload + async def call_single( + self, + messages: list[Message] | str, + callbacks: ( + Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None + ) = ..., + name: str | None = ..., + output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., + tools: list[Tool] | None = ..., + tool_choice: Tool | str | None = ..., + stream: Literal[True] = ..., + **kwargs, + ) -> AsyncGenerator[LLMResult]: ... async def call_single( self, @@ -381,11 +454,10 @@ async def call_single( output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None, tools: list[Tool] | None = None, tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, + stream: bool = False, **kwargs, - ) -> LLMResult: - if isinstance(messages, str): - # convenience for single message - messages = [Message(content=messages)] + ) -> LLMResult | AsyncGenerator[LLMResult]: + kwargs = kwargs | {"n": 1} results = await self.call( messages, callbacks, @@ -393,9 +465,17 @@ async def call_single( output_type, tools, tool_choice, - n=1, + stream, **kwargs, ) + + if stream: + if not isinstance(results, AsyncGenerator): + raise TypeError("Expected AsyncGenerator of results when streaming") + return results + + if not isinstance(results, list): + raise TypeError("Expected list of results when not streaming") if len(results) != 1: # Can be caused by issues like https://github.com/BerriAI/litellm/issues/12298 raise ValueError(f"Got {len(results)} results when expecting just one.") @@ -413,8 +493,8 @@ def rate_limited( @overload def rate_limited( - func: Callable[P, AsyncIterable[LLMResult]], -) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ... + func: Callable[P, AsyncGenerator[LLMResult]], +) -> Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]]: ... def rate_limited(func): @@ -440,7 +520,7 @@ async def wrapper(self, *args, **kwargs): # portion before yielding if isasyncgenfunction(func): - async def rate_limited_generator() -> AsyncIterable[LLMResult]: + async def rate_limited_generator() -> AsyncGenerator[LLMResult]: async for item in func(self, *args, **kwargs): token_count = 0 if isinstance(item, LLMResult): @@ -469,8 +549,8 @@ def request_limited( @overload def request_limited( - func: Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]], -) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ... + func: Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]], +) -> Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]]: ... 
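
For reference, a minimal usage sketch of the new `stream` flag on `LLMModel.call`, as described by the overloads and docstring above. The import paths, model name, and prompt are assumptions for illustration only; per the checks added in this diff, streaming requires `n == 1` and does not support tools or callbacks.

```python
import asyncio

from aviary.core import Message  # assumed import path, matching llms.py above
from lmi.llms import LiteLLMModel  # assumed import path within this package


async def main() -> None:
    # Placeholder model name; any LiteLLM-compatible model should work.
    llm = LiteLLMModel(name="gpt-4o-mini", config={"n": 1})
    messages = [Message(content="Respond with a single word.")]

    # Default (stream=False): returns list[LLMResult], one entry per completion.
    results = await llm.call(messages)
    print(results[0].text)

    # stream=True: returns AsyncGenerator[LLMResult]; each yielded result
    # carries the text accumulated so far plus timing fields.
    stream = await llm.call(messages, stream=True)
    async for partial in stream:
        print(partial.text)


asyncio.run(main())
```
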
def request_limited(func): @@ -487,7 +567,7 @@ async def wrapper(self, *args, **kwargs): if isasyncgenfunction(func): - async def request_limited_generator() -> AsyncIterable[LLMResult]: + async def request_limited_generator() -> AsyncGenerator[LLMResult]: first_item = True async for item in func(self, *args, **kwargs): # Skip rate limit check for first item since we already checked at generator start @@ -608,16 +688,6 @@ def maybe_set_config_attribute(cls, input_data: dict[str, Any]) -> dict[str, Any _DeploymentTypedDictValidator.validate_python(model_list) return data - # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice - # > `none` means the model will not call any tool and instead generates a message. - # > `auto` means the model can pick between generating a message or calling one or more tools. - # > `required` means the model must call one or more tools. - NO_TOOL_CHOICE: ClassVar[str] = "none" - MODEL_CHOOSES_TOOL: ClassVar[str] = "auto" - TOOL_CHOICE_REQUIRED: ClassVar[str] = "required" - # None means we won't provide a tool_choice to the LLM API - UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None - def __getstate__(self): # Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953 state = super().__getstate__() @@ -719,7 +789,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult @rate_limited async def acompletion_iter( self, messages: list[Message], **kwargs - ) -> AsyncIterable[LLMResult]: + ) -> AsyncGenerator[LLMResult]: # cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641 prompts = cast( "list[litellm.types.llms.openai.AllMessageValues]", @@ -728,7 +798,7 @@ async def acompletion_iter( stream_options = { "include_usage": True, } - # NOTE: Specifically requesting reasoning for deepseek-r1 models + if kwargs.get("include_reasoning"): stream_options["include_reasoning"] = True @@ -740,43 +810,37 @@ async def acompletion_iter( **kwargs, ) start_clock = asyncio.get_running_loop().time() - outputs = [] + accumulated_text = "" logprobs = [] role = None reasoning_content = [] used_model = None + async for completion in stream_completions: if not used_model: used_model = completion.model or self.name choice = completion.choices[0] delta = choice.delta - # logprobs can be None, or missing a content attribute, - # or a ChoiceLogprobs object with a NoneType/empty content attribute + + if delta.content: + seconds_to_first_token = asyncio.get_running_loop().time() - start_clock + if logprob_content := getattr(choice.logprobs, "content", None): logprobs.append(logprob_content[0].logprob or 0) - outputs.append(delta.content or "") - role = delta.role or role - if hasattr(delta, "reasoning_content"): - reasoning_content.append(delta.reasoning_content or "") - text = "".join(outputs) - result = LLMResult( - model=used_model, - text=text, - prompt=messages, - messages=[Message(role=role, content=text)], - logprob=sum_logprobs(logprobs), - reasoning_content="".join(reasoning_content), - ) - - if text: - result.seconds_to_first_token = ( - asyncio.get_running_loop().time() - start_clock + if delta.content: + accumulated_text += delta.content + role = delta.role or role + if hasattr(delta, "reasoning_content"): + reasoning_content.append(delta.reasoning_content or "") + yield LLMResult( + model=used_model, + text=accumulated_text, + prompt=messages, + messages=[Message(role=role, content=accumulated_text)], + logprob=sum_logprobs(logprobs), + 
reasoning_content="".join(reasoning_content), + seconds_to_first_token=seconds_to_first_token, ) - if hasattr(completion, "usage"): - result.prompt_count = completion.usage.prompt_tokens - result.completion_count = completion.usage.completion_tokens - - yield result def count_tokens(self, text: str) -> int: return litellm.token_counter(model=self.name, text=text) diff --git a/packages/lmi/tests/test_cost_tracking.py b/packages/lmi/tests/test_cost_tracking.py index de49a42a..0e1192b7 100644 --- a/packages/lmi/tests/test_cost_tracking.py +++ b/packages/lmi/tests/test_cost_tracking.py @@ -107,7 +107,7 @@ async def ac(x) -> None: with assert_costs_increased(): await llm.call(messages, [ac]) - @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"]) + # @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"]) @pytest.mark.parametrize( "config", [ diff --git a/packages/lmi/tests/test_llms.py b/packages/lmi/tests/test_llms.py index c2d90591..5a31f1ba 100644 --- a/packages/lmi/tests/test_llms.py +++ b/packages/lmi/tests/test_llms.py @@ -1,6 +1,6 @@ import pathlib import pickle -from collections.abc import AsyncIterator +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator from typing import Any, ClassVar from unittest.mock import Mock, patch @@ -286,6 +286,7 @@ def accum(x) -> None: messages=messages, callbacks=[accum], ) + assert isinstance(completion, LLMResult) assert completion.model == CommonLLMNames.OPENAI_TEST.value assert completion.seconds_to_last_token > 0 assert completion.prompt_count > 0 @@ -296,6 +297,7 @@ def accum(x) -> None: completion = await llm.call_single( messages=messages, ) + assert isinstance(completion, LLMResult) assert completion.seconds_to_last_token > 0 assert completion.cost > 0 @@ -307,6 +309,7 @@ async def ac(x) -> None: messages=messages, callbacks=[accum, ac], ) + assert isinstance(completion, LLMResult) assert completion.cost > 0 with subtests.test(msg="passing-kwargs"): @@ -314,16 +317,28 @@ async def ac(x) -> None: messages=[Message(role="user", content="Tell me a very long story")], max_tokens=1000, ) + assert isinstance(completion, LLMResult) assert completion.cost > 0 assert completion.completion_count > 100, "Expected a long completion" with subtests.test(msg="autowraps message"): - def mock_call(messages, *_, **__): + def mock_acompletion(messages, *_, **__): assert isinstance(messages, list) - return [None] + assert len(messages) == 1 + assert isinstance(messages[0], Message) + assert messages[0].content == "Test message" + return [ + LLMResult( + model="test-model", + messages=messages, + text="Test result", + ) + ] - with patch.object(LiteLLMModel, "call", side_effect=mock_call): + with patch.object( + LiteLLMModel, "acompletion", side_effect=mock_acompletion + ): await llm.call_single("Test message") @pytest.mark.vcr @@ -447,12 +462,7 @@ def _build_mock_completion( # Mock completion with valid logprobs mock_completion_valid = _build_mock_completion( - logprobs=Mock(content=[Mock(logprob=-0.5)]) - ) - - # Mock completion with usage info - mock_completion_usage = _build_mock_completion( - usage=Mock(prompt_tokens=10, completion_tokens=5) + logprobs=Mock(content=[Mock(logprob=-0.5)], delta_content="") ) # Create async generator that yields mock completions @@ -462,7 +472,6 @@ async def mock_stream_iter(): # noqa: RUF029 yield mock_completion_no_content yield mock_completion_empty yield mock_completion_valid - yield mock_completion_usage return mock_stream_iter() @@ -473,14 +482,12 @@ async def mock_stream_iter(): # noqa: RUF029 
results = [result async for result in async_iterable] # Verify we got one final result - assert len(results) == 1 - result = results[0] + assert len(results) > 1 + result = results[-1] assert isinstance(result, LLMResult) assert result.text == "Hello world!" assert result.model == "test-model" assert result.logprob == -0.5 - assert result.prompt_count == 10 - assert result.completion_count == 5 class DummyOutputSchema(BaseModel): @@ -498,7 +505,9 @@ class TestMultipleCompletion: DEFAULT_CONFIG: ClassVar[dict] = {"n": NUM_COMPLETIONS} MODEL_CLS: ClassVar[type[LiteLLMModel]] = LiteLLMModel - async def call_model(self, model: LiteLLMModel, *args, **kwargs) -> list[LLMResult]: + async def call_model( + self, model: LiteLLMModel, *args, **kwargs + ) -> list[LLMResult] | AsyncIterable[LLMResult]: return await model.call(*args, **kwargs) @pytest.mark.parametrize( @@ -544,6 +553,7 @@ async def test_model(self, model_name: str) -> None: Message(content="Hello, how are you?"), ] results = await self.call_model(model, messages) + assert isinstance(results, list) assert len(results) == self.NUM_COMPLETIONS for result in results: @@ -558,10 +568,13 @@ async def test_model(self, model_name: str) -> None: @pytest.mark.parametrize( "model_name", - [CommonLLMNames.ANTHROPIC_TEST.value, CommonLLMNames.GPT_35_TURBO.value], + [ + pytest.param(CommonLLMNames.ANTHROPIC_TEST.value, id="anthropic"), + pytest.param(CommonLLMNames.GPT_35_TURBO.value, id="openai"), + ], ) @pytest.mark.asyncio - async def test_streaming(self, model_name: str) -> None: + async def test_streaming(self, model_name: str, subtests) -> None: model = self.MODEL_CLS(name=model_name, config=self.DEFAULT_CONFIG) messages = [ Message(role="system", content="Respond with single words."), @@ -571,11 +584,34 @@ async def test_streaming(self, model_name: str) -> None: def callback(_) -> None: return - with pytest.raises( - NotImplementedError, - match="Multiple completions with callbacks is not supported", + with ( + subtests.test(name="n=2"), + pytest.raises( + ValueError, + match=r"\(n\) must be 1", + ), ): - await self.call_model(model, messages, [callback]) + await self.call_model(model, messages, stream=True, callbacks=[callback]) + + config = self.DEFAULT_CONFIG.copy() + config["n"] = 1 + model = self.MODEL_CLS(name=model_name, config=config) + with ( + subtests.test(name="with callbacks"), + pytest.raises( + NotImplementedError, + match="callbacks with streaming is not supported", + ), + ): + await self.call_model(model, messages, stream=True, callbacks=[callback]) + + with subtests.test(name="n=1"): + results = await self.call_model(model, messages, stream=True) + assert isinstance(results, AsyncGenerator) + async for result in results: + assert isinstance(result, LLMResult) + assert result.messages + assert len(result.messages) == 1 @pytest.mark.vcr @pytest.mark.asyncio @@ -594,6 +630,7 @@ def play(move: int | None) -> None: messages=[Message(content="Please win.")], tools=[Tool.from_function(play)], ) + assert isinstance(results, list) assert len(results) == self.NUM_COMPLETIONS for result in results: assert result.messages @@ -636,6 +673,7 @@ async def test_output_schema( ), ] results = await self.call_model(model, messages, output_type=output_type) + assert isinstance(results, list) assert len(results) == self.NUM_COMPLETIONS for result in results: assert result.messages @@ -662,6 +700,7 @@ async def test_text_image_message(self, model_name: str) -> None: ) ], ) + assert isinstance(results, list) assert len(results) == 
self.NUM_COMPLETIONS for result in results: assert result.messages is not None, ( @@ -678,7 +717,7 @@ async def test_text_image_message(self, model_name: str) -> None: ) @pytest.mark.asyncio @pytest.mark.vcr - async def test_single_completion(self, model_name: str) -> None: + async def test_single_completion(self, model_name: str, subtests) -> None: model = self.MODEL_CLS(name=model_name, config={"n": 1}) messages = [ Message(role="system", content="Respond with single words."), @@ -687,7 +726,6 @@ async def test_single_completion(self, model_name: str) -> None: result = await model.call_single(messages) assert isinstance(result, LLMResult) - assert isinstance(result, LLMResult) assert result.messages assert len(result.messages) == 1 assert result.messages[0].content @@ -720,10 +758,12 @@ async def test_multiple_completion(self, model_name: str, request) -> None: await model.call(messages) else: results = await model.call(messages) # noqa: FURB120 + assert isinstance(results, list) assert len(results) == self.NUM_COMPLETIONS model = self.MODEL_CLS(name=model_name, config={"n": 5}) results = await model.call(messages, n=self.NUM_COMPLETIONS) + assert isinstance(results, list) assert len(results) == self.NUM_COMPLETIONS @@ -813,6 +853,7 @@ async def test_empty_tools(self, tools: list | None, model_name: str) -> None: tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL, ) + assert isinstance(result, LLMResult) assert isinstance(result.messages, list) if tools is None: assert isinstance(result.messages[0], Message) @@ -861,11 +902,13 @@ async def test_deepseek_model(self, llm_name: str) -> None: Message(content="What is the meaning of life?"), ] results = await llm.call(messages) + assert isinstance(results, list) for result in results: assert result.reasoning_content outputs: list[str] = [] results = await llm.call(messages, callbacks=[outputs.append]) + assert isinstance(results, list) for i, result in enumerate(results): assert result.reasoning_content @@ -897,6 +940,7 @@ async def test_anthropic_model(self) -> None: Message(content="What is the meaning of life?"), ] results = await llm.call(messages) + assert isinstance(results, list) for result in results: assert result.reasoning_content is not None, "Should have reasoning content" assert result.text is not None
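
A corresponding sketch of the reworked `call_single`, which now auto-wraps a bare string prompt into a single `Message` and forwards `stream` to `call` while forcing `n=1`. The model name and import path are again illustrative assumptions, not part of this diff.

```python
import asyncio

from lmi.llms import LiteLLMModel  # assumed import path, as in the sketch above


async def main() -> None:
    llm = LiteLLMModel(name="gpt-4o-mini", config={"n": 1})  # placeholder model name

    # A bare string is auto-wrapped into a single user Message (exercised by the
    # "autowraps message" subtest above); with the default stream=False this
    # returns exactly one LLMResult.
    result = await llm.call_single("Say hello in one word.")
    print(result.text)

    # With stream=True, call_single forces n=1 and returns the async generator
    # produced by call(); consume it like any other async iterator.
    stream = await llm.call_single("Say hello in one word.", stream=True)
    async for partial in stream:
        print(partial.text)


asyncio.run(main())
```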