Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 82 additions & 50 deletions src/paperqa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, ClassVar

import litellm
from aviary.core import Message
Expand Down Expand Up @@ -124,14 +124,42 @@ def fraction_replacer(match: re.Match) -> str:
return data


class LLMBadContextJSONError(ValueError):
"""Retryable exception for when the LLM gives back bad JSON."""
class LLMContextError(ValueError):
retryable: ClassVar[bool]
help_message: ClassVar[str] # Eventually passed to logger.exception

def __init__(self, message: str, llm_results: list[LLMResult]) -> None:
super().__init__(message)
self.llm_results = llm_results # House so we can cost track across retries


class LLMBadContextJSONError(LLMContextError):
"""Retryable exception for when the LLM gives back bad JSON."""

retryable = True
help_message = (
"Abandoning this context creation."
" Your model may not be capable of supporting JSON output"
" or our parsing technique could use some work. Try"
" a different model or specify `Settings(prompts={'use_json': False})`."
" Or, feel free to just ignore this message, as many contexts are"
" concurrently made and we're not attached to any one given context."
)


class LLMContextTimeoutError(LLMContextError):
"""Non-retryable exception for when the LLM call times out."""

retryable = False
help_message = (
"Timeout when creating a context, abandoning it."
" If you see this error frequently, consider increasing the timeout in"
" Settings(summary_llm_config=...). Or, feel free to just ignore this message,"
" as many contexts are concurrently made and we're not attached to any one"
" given context."
)


async def _map_fxn_summary( # noqa: PLR0912
text: Text,
question: str,
Expand All @@ -142,7 +170,7 @@ async def _map_fxn_summary( # noqa: PLR0912
callbacks: Sequence[Callable[[str], None]] | None = None,
skip_citation_strip: bool = False,
evidence_text_only_fallback: bool = False,
_prior_attempt: LLMBadContextJSONError | None = None,
_prior_attempt: LLMContextError | None = None,
) -> tuple[Context, list[LLMResult]]:
"""Parses the given text and returns a context object with the parser and prompt runner.

Expand Down Expand Up @@ -204,41 +232,48 @@ async def _map_fxn_summary( # noqa: PLR0912
} | (extra_prompt_data or {})
message_prompt, system_prompt = (pt.format(**data) for pt in prompt_templates)
try:
llm_result = await summary_llm_model.call_single(
messages=[
Message(role="system", content=system_prompt),
Message.create_message(
text=message_prompt,
images=(
[i.to_image_url() for i in text.media]
if text.media
else None
try:
llm_result = await summary_llm_model.call_single(
messages=[
Message(role="system", content=system_prompt),
Message.create_message(
text=message_prompt,
images=(
[i.to_image_url() for i in text.media]
if text.media
else None
),
),
),
*append_msgs,
],
callbacks=callbacks,
name="evidence:" + text.name,
)
except litellm.BadRequestError as exc:
if not evidence_text_only_fallback:
raise
logger.warning(
f"LLM call to create a context failed with exception {exc!r}"
f" on text named {text.name!r}"
f" with doc name {text.doc.docname!r} and doc key {text.doc.dockey!r}."
f" Retrying without media."
)
llm_result = await summary_llm_model.call_single(
messages=[
Message(role="system", content=system_prompt),
Message(content=message_prompt),
*append_msgs,
],
callbacks=callbacks,
name="evidence:" + text.name,
)
used_text_only_fallback = True
*append_msgs,
],
callbacks=callbacks,
name="evidence:" + text.name,
)
except litellm.BadRequestError as exc:
if not evidence_text_only_fallback:
raise
logger.warning(
f"LLM call to create a context failed with exception {exc!r}"
f" on text named {text.name!r}"
f" with doc name {text.doc.docname!r} and doc key {text.doc.dockey!r}."
f" Retrying without media."
)
llm_result = await summary_llm_model.call_single(
messages=[
Message(role="system", content=system_prompt),
Message(content=message_prompt),
*append_msgs,
],
callbacks=callbacks,
name="evidence:" + text.name,
)
used_text_only_fallback = True
except litellm.Timeout as exc:
raise LLMContextTimeoutError(
f"LLM call to create a context timed out on text named {text.name!r}.",
llm_results=llm_results,
) from exc

llm_results.append(llm_result)
context = llm_result.text or ""
try:
Expand Down Expand Up @@ -281,8 +316,7 @@ async def _map_fxn_summary( # noqa: PLR0912
score = extract_score(context)
except ValueError as exc:
raise LLMBadContextJSONError(
f"Extracting score from raw context {context!r}"
f" failed due to: {exc}",
f"Extracting score from raw context {context!r} failed due to: {exc}",
llm_results=llm_results,
) from exc
else:
Expand Down Expand Up @@ -325,16 +359,14 @@ async def map_fxn_summary(**kwargs) -> tuple[Context | None, list[LLMResult]]:
)
try:
return await _map_fxn_summary(**kwargs)
except LLMBadContextJSONError as exc:
try:
return await _map_fxn_summary(**kwargs, _prior_attempt=exc)
except LLMBadContextJSONError as exc2:
except LLMContextError as exc:
if not exc.retryable:
logger.exception(
"Failed twice to create a context, abandoning it."
" Your model may not be capable of supporting JSON output"
" or our parsing technique could use some work. Try"
" a different model or specify `Settings(prompts={'use_json': False})`."
" Or, feel free to just ignore this message, as many contexts are"
" concurrently made and we're not attached to any one given context."
"Non-retryable failure creating a context. %s", exc.help_message
)
return None, exc.llm_results
try:
return await _map_fxn_summary(**kwargs, _prior_attempt=exc)
except LLMContextError as exc2:
logger.exception("Failed twice to create a context. %s", exc2.help_message)
return None, exc2.llm_results
82 changes: 65 additions & 17 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
EmbeddingModel,
HybridEmbeddingModel,
LiteLLMEmbeddingModel,
LiteLLMModel,
LLMModel,
LLMResult,
SparseEmbeddingModel,
Expand All @@ -53,7 +54,12 @@
)
from paperqa.clients import CrossrefProvider
from paperqa.clients.journal_quality import JournalQualityPostProcessor
from paperqa.core import llm_parse_json
from paperqa.core import (
LLMContextTimeoutError,
_map_fxn_summary, # noqa: PLC2701
llm_parse_json,
map_fxn_summary,
)
from paperqa.prompts import CANNOT_ANSWER_PHRASE
from paperqa.prompts import qa_prompt as default_qa_prompt
from paperqa.readers import PDFParserFn, parse_image, read_doc
Expand Down Expand Up @@ -104,10 +110,7 @@ def test_multiple_authors() -> None:


def test_multiple_citations() -> None:
text = (
"As discussed by several authors (Smith et al. 1999; Johnson 2001; Lee et al."
" 2003)."
)
text = "As discussed by several authors (Smith et al. 1999; Johnson 2001; Lee et al. 2003)."
assert strip_citations(text) == "As discussed by several authors ."


Expand Down Expand Up @@ -620,7 +623,6 @@ async def test_query(docs_fixture) -> None:

@pytest.mark.asyncio
async def test_custom_context_str_fn(docs_fixture) -> None:

async def custom_context_str_fn( # noqa: RUF029
settings: Settings, # noqa: ARG001
contexts: list[Context], # noqa: ARG001
Expand All @@ -646,7 +648,6 @@ async def custom_context_str_fn( # noqa: RUF029

@pytest.mark.asyncio
async def test_aquery_groups_contexts_by_question(docs_fixture) -> None:

session = PQASession(question="What is the relationship between chemistry and AI?")

doc = Doc(docname="test_doc", citation="Test Doc, 2025", dockey="key1")
Expand Down Expand Up @@ -733,12 +734,14 @@ async def test_query_with_iteration(docs_fixture) -> None:
prior_session = PQASession(question=question, answer=prior_answer)
await docs_fixture.aquery(prior_session, llm_model=llm, settings=settings)
assert prior_answer in cast(
"str", my_results[-1].prompt[1].content # type: ignore[union-attr, index]
"str",
my_results[-1].prompt[1].content, # type: ignore[union-attr, index]
), "prior answer not in prompt"
# run without a prior session to check that the flow works correctly
await docs_fixture.aquery(question, llm_model=llm, settings=settings)
assert settings.prompts.answer_iteration_prompt[:10] not in cast( # type: ignore[index]
"str", my_results[-1].prompt[1].content # type: ignore[union-attr, index]
"str",
my_results[-1].prompt[1].content, # type: ignore[union-attr, index]
), "prior answer prompt should not be inserted"


Expand Down Expand Up @@ -962,7 +965,9 @@ class StubLLMModel(LLMModel):
name: str = "custom/myllm"

async def acompletion(
self, messages: list[Message], **kwargs # noqa: ARG002
self,
messages: list[Message],
**kwargs, # noqa: ARG002
) -> list[LLMResult]:
return [
LLMResult(
Expand All @@ -976,7 +981,9 @@ async def acompletion(

@rate_limited
async def acompletion_iter(
self, messages: list[Message], **kwargs # noqa: ARG002
self,
messages: list[Message],
**kwargs, # noqa: ARG002
) -> AsyncIterable[LLMResult]:
yield LLMResult(
model=self.name,
Expand Down Expand Up @@ -1708,10 +1715,7 @@ async def test_custom_prompts(stub_data_dir: Path) -> None:

@pytest.mark.asyncio
async def test_pre_prompt(stub_data_dir: Path) -> None:
pre = (
"What is water's boiling point in Fahrenheit? Please respond with a complete"
" sentence."
)
pre = "What is water's boiling point in Fahrenheit? Please respond with a complete sentence."

settings = Settings.from_name("fast")
settings.prompts.pre = pre
Expand Down Expand Up @@ -2439,8 +2443,7 @@ def test_llm_parse_json_with_escaped_characters(self, input_text, expected_outpu
def test_llm_subquotes_and_newlines(self, input_text: str) -> None:
output = {
"summary": (
'An excerpt with "quoted stuff" or "maybe more." More stuff (with'
" parenthesis)."
'An excerpt with "quoted stuff" or "maybe more." More stuff (with parenthesis).'
),
"relevance_score": 8,
}
Expand Down Expand Up @@ -2616,3 +2619,48 @@ def test_pqa_context_id_parsing(raw_text: str, cleaned_text: str) -> None:
)
session.populate_formatted_answers_and_bib_from_raw_answer()
assert session.answer == cleaned_text


@pytest.mark.asyncio
async def test_timeout_resilience() -> None:
model_name = CommonLLMNames.ANTHROPIC_TEST.value
short_timeout = 0.001
llm = LiteLLMModel(
name=CommonLLMNames.ANTHROPIC_TEST.value,
config={
"model_list": [
{
"model_name": model_name,
"litellm_params": {
"model": model_name,
"timeout": short_timeout,
},
}
],
"router_kwargs": {"timeout": short_timeout},
},
)

# Make sure we've configured timeout low enough for this test to be useful
with pytest.raises(litellm.Timeout):
await llm.call_single("The duck says")

text = Text(
text="The duck says",
name="test",
doc=Doc(docname="test", dockey="test", citation="test"),
)
kw = {
"text": text,
"question": "The duck says",
"summary_llm_model": llm,
"prompt_templates": ("", ""),
}
# This *should* raise
with pytest.raises(LLMContextTimeoutError):
await _map_fxn_summary(**kw) # type: ignore[arg-type]

# The wrapped version should not raise, but return empty
context, llm_results = await map_fxn_summary(**kw)
assert context is None
assert not llm_results