Skip to content
Merged
131 changes: 98 additions & 33 deletions src/paperqa/core.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from __future__ import annotations

import contextlib
import json
import logging
import re
from collections.abc import Callable, Sequence
from typing import Any, cast
from typing import Any

import litellm
from aviary.core import Message
from lmi import LLMModel
from lmi import LLMModel, LLMResult

from paperqa.prompts import text_with_tables_prompt_template
from paperqa.types import Context, LLMResult, Text
from paperqa.types import Context, Text
from paperqa.utils import extract_score, strip_citations

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -108,11 +106,7 @@ def fraction_replacer(match: re.Match) -> str:
"relevance_score": int(score_match.group(1)),
}

raise ValueError(
f"Failed to parse JSON from text {text!r}. Your model may not be capable of"
" supporting JSON output or our parsing technique could use some work. Try"
" a different model or specify `Settings(prompts={'use_json': False})`"
) from e
raise ValueError(f"Failed to load JSON from text {text!r}.") from e

# Handling incorrect key names for "relevance_score"
for key in list(data.keys()):
Expand All @@ -132,7 +126,15 @@ def fraction_replacer(match: re.Match) -> str:
return data


async def map_fxn_summary(
class LLMBadContextJSONError(ValueError):
    """Retryable error for unusable JSON in an LLM's context output.

    Carries every LLMResult produced so far, so callers can keep
    accounting for token usage and cost across retry attempts.
    """

    def __init__(self, message: str, llm_results: list[LLMResult]) -> None:
        # Keep the results on the exception so cost tracking survives retries.
        self.llm_results = llm_results
        super().__init__(message)


async def _map_fxn_summary( # noqa: PLR0912
text: Text,
question: str,
summary_llm_model: LLMModel | None,
Expand All @@ -142,7 +144,8 @@ async def map_fxn_summary(
callbacks: Sequence[Callable[[str], None]] | None = None,
skip_citation_strip: bool = False,
evidence_text_only_fallback: bool = False,
) -> tuple[Context, LLMResult]:
_prior_attempt: LLMBadContextJSONError | None = None,
) -> tuple[Context, list[LLMResult]]:
"""Parses the given text and returns a context object with the parser and prompt runner.

The parser should at least return a dict with `summary`. A `relevant_score` will be used and any
Expand All @@ -162,15 +165,25 @@ async def map_fxn_summary(
skip_citation_strip: Optional skipping of citation stripping, if you want to keep in the context.
evidence_text_only_fallback: Opt-in flag to allow retrying context creation
without media in the completion.
_prior_attempt: Optional failure from a prior attempt, for LLM result tracking.

Returns:
The context object and LLMResult to get info about the LLM execution.
A two-tuple of the made Context, and any LLM results made along the way.
"""
# needed empties for failures/skips
llm_result = LLMResult(model="", date="")
if _prior_attempt is not None:
llm_results = _prior_attempt.llm_results
append_msgs = [
Message(
content=(
"In a prior attempt, we failed with this failure message:"
f" {_prior_attempt!s}."
)
)
]
else:
llm_results, append_msgs = [], []
extras: dict[str, Any] = {}
citation = text.name + ": " + text.doc.formatted_citation
success = False
used_text_only_fallback = False

# Strip newlines in case chunking led to blank lines,
Expand Down Expand Up @@ -204,6 +217,7 @@ async def map_fxn_summary(
else None
),
),
*append_msgs,
],
callbacks=callbacks,
name="evidence:" + text.name,
Expand All @@ -221,40 +235,69 @@ async def map_fxn_summary(
messages=[
Message(role="system", content=system_prompt),
Message(content=message_prompt),
*append_msgs,
],
callbacks=callbacks,
name="evidence:" + text.name,
)
used_text_only_fallback = True
context = cast("str", llm_result.text)
result_data = parser(context) if parser else {}
success = bool(result_data)
if success:
llm_results.append(llm_result)
context = llm_result.text or ""
try:
result_data = parser(context) if parser else {}
except ValueError as exc:
raise LLMBadContextJSONError(
f"Failed to parse JSON from context {context!r} due to: {exc}",
llm_results=llm_results,
) from exc
if result_data:
try:
context = result_data.pop("summary")
score = (
result_data.pop("relevance_score")
if "relevance_score" in result_data
else extract_score(context)
)
try:
score = (
result_data.pop("relevance_score")
if "relevance_score" in result_data
else extract_score(context)
)
except ValueError as exc:
raise LLMBadContextJSONError(
f"Successfully parsed JSON and extracted 'summary' key,"
f" but then failed to extract score from context {context!r} due to: {exc}",
llm_results=llm_results,
) from exc
# just in case question was present
result_data.pop("question", None)
extras = result_data
except KeyError:
success = False
except KeyError: # No summary key, so extract from LLM result
try:
score = extract_score(context)
except ValueError as exc:
raise LLMBadContextJSONError(
f"Successfully parsed JSON but it had no 'summary' key."
f" Then, the failover to extract score from raw context {context!r}"
f" failed due to: {exc}",
llm_results=llm_results,
) from exc
else:
try:
score = extract_score(context)
except ValueError as exc:
raise LLMBadContextJSONError(
f"Extracting score from raw context {context!r}"
f" failed due to: {exc}",
llm_results=llm_results,
) from exc
else:
llm_results.append(LLMResult(model="", date=""))
context = cleaned_text
# If we don't assign scores, just default to 5.
# why 5? Because we filter out 0s in another place
# and 5/10 is the other default I could come up with
# and 5/10 is the only default I could come up with
score = 5
success = True
# remove citations that collide with our grounded citations (for the answer LLM)
if not skip_citation_strip:
context = strip_citations(context)

if not success:
score = extract_score(context)
if used_text_only_fallback:
extras["used_images"] = False

Expand All @@ -270,8 +313,30 @@ async def map_fxn_summary(
doc=text.doc.model_dump(exclude={"embedding"}),
**text.model_dump(exclude={"embedding", "doc"}),
),
score=score, # pylint: disable=possibly-used-before-assignment
score=score,
**extras,
),
llm_result,
llm_results,
)


async def map_fxn_summary(**kwargs) -> tuple[Context | None, list[LLMResult]]:
    """Build a Context for a text, retrying once when the LLM emits bad JSON.

    Thin public wrapper over `_map_fxn_summary`: a first LLMBadContextJSONError
    triggers one retry that feeds the failure message back to the LLM; a second
    failure logs and abandons the context, returning ``(None, llm_results)`` so
    token usage accumulated across both attempts can still be tracked.

    Raises:
        ValueError: If the internal-only ``_prior_attempt`` keyword is passed.
    """
    if "_prior_attempt" in kwargs:
        raise ValueError(
            "_prior_attempt is reserved for internal use only, don't specify it."
        )
    try:
        # First attempt: no prior-failure feedback to the LLM.
        return await _map_fxn_summary(**kwargs)
    except LLMBadContextJSONError as first_failure:
        try:
            # Second attempt: tell the LLM how the first attempt went wrong.
            return await _map_fxn_summary(**kwargs, _prior_attempt=first_failure)
        except LLMBadContextJSONError as second_failure:
            logger.exception(
                "Failed twice to create a context, abandoning it."
                " Your model may not be capable of supporting JSON output"
                " or our parsing technique could use some work. Try"
                " a different model or specify `Settings(prompts={'use_json': False})`."
                " Or, feel free to just ignore this message, as many contexts are"
                " concurrently made and we're not attached to any one given context."
            )
            return None, second_failure.llm_results
7 changes: 4 additions & 3 deletions src/paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,11 @@ async def aget_evidence(
],
)

for _, llm_result in results:
session.add_tokens(llm_result)
for _, llm_results in results:
for r in llm_results:
session.add_tokens(r)

session.contexts += [r for r, _ in results]
session.contexts += [c for c, _ in results if c is not None]
return session

def query(
Expand Down
81 changes: 64 additions & 17 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from io import BytesIO
from pathlib import Path
from typing import cast
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch
from uuid import UUID

import httpx
Expand Down Expand Up @@ -493,21 +493,41 @@ async def test_evidence(docs_fixture: Docs) -> None:
), "Expected unique contexts"
texts = {c.text for c in evidence}
assert texts, "Below assertions require at least one text to be used"

# Okay, let's check we can get other evidence using the same underlying sources
other_evidence = (
await docs_fixture.aget_evidence(
PQASession(question="What is an acronym for explainable AI?"),
settings=debug_settings,
)
).contexts
orig_acompletion = litellm.acompletion
has_made_scoreless_context = False
no_score_context_body = "MAKEUNIQUE Explainable Artificial Intelligence (XAI)"

async def acompletion_that_breaks_first_context(*args, **kwargs):
completion = await orig_acompletion(*args, **kwargs)
nonlocal has_made_scoreless_context
if not has_made_scoreless_context:
assert len(completion.choices) == 1, "Test expects one choice"
completion.choices[0].message.content = no_score_context_body
has_made_scoreless_context = True
return completion

# Let's check we are resilient to bad context creation
with patch.object(litellm, "acompletion", acompletion_that_breaks_first_context):
# Let's also check we can get other evidence using the same underlying sources
other_evidence = (
await docs_fixture.aget_evidence(
PQASession(question="What is an acronym for explainable AI?"),
settings=debug_settings,
)
).contexts
(no_score_context,) = (
c for c in other_evidence if c.context == no_score_context_body
)
assert (
no_score_context.score == 1
), "Expected context without score to be downweighted"
assert texts.intersection(
{c.text for c in other_evidence}
), "We should be able to reuse sources across evidence calls"


@pytest.mark.asyncio
async def test_json_evidence(docs_fixture) -> None:
async def test_json_evidence(docs_fixture: Docs) -> None:
settings = Settings.from_name("fast")
settings.prompts.use_json = True
settings.prompts.summary_json_system = (
Expand All @@ -520,13 +540,38 @@ async def test_json_evidence(docs_fixture) -> None:
" author , and `relevance_score` is the relevance of `summary` to answer the"
" question (integer out of 10)."
)
evidence = (
await docs_fixture.aget_evidence(
PQASession(question="Who wrote this article?"),
settings=settings,
)
).contexts
assert evidence[0].author_name
orig_acompletion = litellm.acompletion
has_made_bad_json_context = False
bad_json_context = ( # Broken summary and relevance_score
'{\n "summary": "Complete of th'
'\n "author_name": "Sentinel value.",'
'\n "relevance_score": "A"\n}'
)

async def acompletion_that_breaks_first_context(*args, **kwargs):
completion = await orig_acompletion(*args, **kwargs)
nonlocal has_made_bad_json_context
if not has_made_bad_json_context:
assert len(completion.choices) == 1, "Test expects one choice"
completion.choices[0].message.content = bad_json_context
has_made_bad_json_context = True
return completion

# Let's check we are resilient to bad context creation
with patch.object(litellm, "acompletion", acompletion_that_breaks_first_context):
evidence = (
await docs_fixture.aget_evidence(
PQASession(question="Who wrote this article?"),
settings=settings,
)
).contexts
evidence_with_authors = [
c for c in evidence if hasattr(c, "author_name") and c.author_name
]
assert evidence_with_authors
assert all(
"sentinel" not in c.author_name.lower() for c in evidence_with_authors
), "Expected broken context retrying to work"


@pytest.mark.asyncio
Expand Down Expand Up @@ -960,6 +1005,7 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None:
"Echo", summary_llm_model=StubLLMModel(), settings=no_json_settings
)
).contexts
assert evidence
assert "Echo" in evidence[0].context

async def test_callback(result: LLMResult | str) -> None:
Expand All @@ -973,6 +1019,7 @@ async def test_callback(result: LLMResult | str) -> None:
settings=no_json_settings,
)
).contexts
assert evidence
assert "Echo" in evidence[0].context


Expand Down
Loading