Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from collections import defaultdict
from collections.abc import Awaitable, Callable, Mapping, Sequence
from contextlib import suppress
from enum import IntEnum, StrEnum
from itertools import starmap
from pydoc import locate
Expand Down Expand Up @@ -74,7 +75,12 @@
)
from paperqa.readers import PDFParserFn
from paperqa.types import Context, ParsedMedia, ParsedText
from paperqa.utils import hexdigest, parse_enrichment_irrelevance, pqa_directory
from paperqa.utils import (
get_stable_str,
hexdigest,
parse_enrichment_irrelevance,
pqa_directory,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -316,13 +322,10 @@ def _custom_serializer(
if isinstance(self.parse_pdf, str):
# Already JSON-compliant, so let's un-exclude
data["parse_pdf"] = self.parse_pdf
elif (
info.mode == "json"
and hasattr(self.parse_pdf, "__module__")
and hasattr(self.parse_pdf, "__name__")
):
elif info.mode == "json":
# If going to JSON, and we can get a FQN, do so for JSON compliance
data["parse_pdf"] = f"{self.parse_pdf.__module__}.{self.parse_pdf.__name__}"
with suppress(ValueError): # Suppress when not serialization-safe
data["parse_pdf"] = get_stable_str(self.parse_pdf)
return data

doc_filters: Sequence[Mapping[str, Any]] | None = Field(
Expand Down Expand Up @@ -852,7 +855,7 @@ def get_index_name(self) -> str:
first_segment,
str(self.agent.index.use_absolute_paper_directory),
self.embedding,
str(self.parsing.parse_pdf), # Don't use __name__ as lambda wouldn't differ
get_stable_str(self.parsing.parse_pdf, for_hash=True),
str(self.parsing.reader_config["chunk_chars"]),
str(self.parsing.reader_config["overlap"]),
str(self.parsing.reader_config.get("full_page", False)),
Expand Down
30 changes: 29 additions & 1 deletion src/paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections import Counter
from collections.abc import Awaitable, Callable, Collection, Iterable, Iterator, Mapping
from datetime import datetime
from functools import reduce
from functools import partial, reduce
from http import HTTPStatus
from pathlib import Path
from typing import Any, BinaryIO, ClassVar, TypeVar
Expand Down Expand Up @@ -686,3 +686,31 @@ def parse_enrichment_irrelevance(enrichment: str) -> tuple[bool, str]:
return False, enrichment.strip()
no_label_enrichment = enrichment[label_match.end() :] if label_match else enrichment
return is_irrelevant, no_label_enrichment.strip()


def get_stable_str(fn: Callable, for_hash: bool = False) -> str:
    """
    Build a deterministic string for a callable, for serialization or hashing.

    A plain named function is rendered as its fully qualified name
    ``module.name``. This sidesteps two pitfalls of naive stringification:
    - every lambda shares the unhelpful name ``'<lambda>'``, and
    - the default repr embeds the object id, which differs per process
      and would break cache keys built from it.

    Args:
        fn: Callable to stringify.
        for_hash: When True, fall back to the callable's type name for
            objects without a usable FQN instead of raising.

    Returns:
        A stable string representation of ``fn``.

    Raises:
        ValueError: If ``for_hash`` is False and no serialization-safe
            FQN can be formed (lambdas, partials, exotic callables).
    """
    _missing = object()  # sentinel so a literal None attribute still counts as present
    module = getattr(fn, "__module__", _missing)
    name = getattr(fn, "__name__", _missing)
    has_usable_fqn = (
        module is not _missing
        and name is not _missing
        and name != "<lambda>"
        and not isinstance(fn, partial)
    )
    if has_usable_fqn:
        # Named function: the FQN is stable across processes and Python runs
        return f"{module}.{name}"
    if not for_hash:
        raise ValueError(
            f"Cannot form serialization-safe string representation for {fn}."
        )
    # Fallback to lower-quality hash name
    # NOTE: we could look into `fn.__code__`, but as `fn.__code__.co_code`
    # is not stable across Python minor versions,
    # this unfortunately makes the cache key incompatible across Python versions,
    # which is too brittle
    return type(fn).__name__
21 changes: 20 additions & 1 deletion tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections.abc import AsyncIterable, Sequence
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
Expand Down Expand Up @@ -3291,6 +3292,12 @@ async def test_timeout_resilience() -> None:
assert not llm_results


# Deliberately a lambda (not a def): get_stable_str special-cases the
# '<lambda>' name, and this stub exercises that branch.
TEST_STUB_LAMBDA = lambda: 1 # noqa: E731


# Deliberately a functools.partial: get_stable_str rejects partials from the
# FQN path, and this stub exercises that branch.
TEST_STUB_PARTIAL = partial(pymupdf_parse_pdf_to_pages, kwarg=2)


def test_parse_pdf_string_resolution() -> None:
# Test with a valid string FQN
pymupdf_str = Settings(
Expand All @@ -3314,7 +3321,7 @@ def test_parse_pdf_string_resolution() -> None:
)
assert "parse_pdf" not in pypdf_str.model_dump()["parsing"]

# Test directly passing a parser
# Test directly passing a normal parser
pymupdf_fn = Settings(parsing=ParsingSettings(parse_pdf=pymupdf_parse_pdf_to_pages))
assert pymupdf_fn.parsing.parse_pdf == pymupdf_parse_pdf_to_pages
assert (
Expand All @@ -3323,6 +3330,18 @@ def test_parse_pdf_string_resolution() -> None:
)
assert "parse_pdf" not in pymupdf_fn.model_dump()["parsing"]

# Test directly passing a lambda parser
lambda_fn = Settings(parsing=ParsingSettings(parse_pdf=TEST_STUB_LAMBDA))
assert lambda_fn.parsing.parse_pdf == TEST_STUB_LAMBDA
assert "parse_pdf" not in lambda_fn.model_dump(mode="json")["parsing"]
assert "parse_pdf" not in lambda_fn.model_dump()["parsing"]

# Test directly passing a functools partial parser
partial_fn = Settings(parsing=ParsingSettings(parse_pdf=TEST_STUB_PARTIAL))
assert partial_fn.parsing.parse_pdf == TEST_STUB_PARTIAL
assert "parse_pdf" not in partial_fn.model_dump(mode="json")["parsing"]
assert "parse_pdf" not in partial_fn.model_dump()["parsing"]

# Test a nonexistent FQN
with pytest.raises(ValueError, match="Failed to locate"):
Settings(parsing=ParsingSettings(parse_pdf="nonexistent.module.function"))
Expand Down