diff --git a/src/paperqa/settings.py b/src/paperqa/settings.py index 91a5fec03..1a58e3211 100644 --- a/src/paperqa/settings.py +++ b/src/paperqa/settings.py @@ -7,6 +7,7 @@ import warnings from collections import defaultdict from collections.abc import Awaitable, Callable, Mapping, Sequence +from contextlib import suppress from enum import IntEnum, StrEnum from itertools import starmap from pydoc import locate @@ -74,7 +75,12 @@ ) from paperqa.readers import PDFParserFn from paperqa.types import Context, ParsedMedia, ParsedText -from paperqa.utils import hexdigest, parse_enrichment_irrelevance, pqa_directory +from paperqa.utils import ( + get_stable_str, + hexdigest, + parse_enrichment_irrelevance, + pqa_directory, +) logger = logging.getLogger(__name__) @@ -316,13 +322,10 @@ def _custom_serializer( if isinstance(self.parse_pdf, str): # Already JSON-compliant, so let's un-exclude data["parse_pdf"] = self.parse_pdf - elif ( - info.mode == "json" - and hasattr(self.parse_pdf, "__module__") - and hasattr(self.parse_pdf, "__name__") - ): + elif info.mode == "json": # If going to JSON, and we can get a FQN, do so for JSON compliance - data["parse_pdf"] = f"{self.parse_pdf.__module__}.{self.parse_pdf.__name__}" + with suppress(ValueError): # Suppress when not serialization-safe + data["parse_pdf"] = get_stable_str(self.parse_pdf) return data doc_filters: Sequence[Mapping[str, Any]] | None = Field( @@ -852,7 +855,7 @@ def get_index_name(self) -> str: first_segment, str(self.agent.index.use_absolute_paper_directory), self.embedding, - str(self.parsing.parse_pdf), # Don't use __name__ as lambda wouldn't differ + get_stable_str(self.parsing.parse_pdf, for_hash=True), str(self.parsing.reader_config["chunk_chars"]), str(self.parsing.reader_config["overlap"]), str(self.parsing.reader_config.get("full_page", False)), diff --git a/src/paperqa/utils.py b/src/paperqa/utils.py index 2f31a8a4c..2ef93ea72 100644 --- a/src/paperqa/utils.py +++ b/src/paperqa/utils.py @@ -10,7 +10,7 @@ 
def get_stable_str(fn: Callable, for_hash: bool = False) -> str:
    """
    Get a stable string from a function, for serialization (default) or hashing.

    It avoids common pitfalls such as:
    - All lambda functions having the same name '<lambda>'.
    - IDs in normal string representation of functions breaking cache keys.

    Args:
        fn: Callable to describe.
        for_hash: When False (default), only a fully qualified name is
            acceptable and anything else raises. When True, callables without
            a usable fully qualified name fall back to their type's name.

    Returns:
        '{module}.{name}' for a named callable, otherwise (only when for_hash
        is True) the name of the callable's type, e.g. 'function' or 'partial'.

    Raises:
        ValueError: If for_hash is False and fn is a lambda, a functools.partial,
            or lacks a __module__ or __name__ attribute.
    """
    if (
        hasattr(fn, "__module__")
        and hasattr(fn, "__name__")
        # Every lambda is named '<lambda>', so the name can't identify it
        and fn.__name__ != "<lambda>"
        and not isinstance(fn, partial)
    ):
        # Named function: use fully qualified name
        return f"{fn.__module__}.{fn.__name__}"
    if not for_hash:
        raise ValueError(
            f"Cannot form serialization-safe string representation for {fn}."
        )
    # Fallback to lower-quality hash name
    # NOTE: we could look into `fn.__code__`, but as `fn.__code__.co_code`
    # is not stable across Python minor versions,
    # this unfortunately makes the cache key incompatible across Python versions,
    # which is too brittle
    return type(fn).__name__
partial_fn.parsing.parse_pdf == TEST_STUB_PARTIAL + assert "parse_pdf" not in partial_fn.model_dump(mode="json")["parsing"] + assert "parse_pdf" not in partial_fn.model_dump()["parsing"] + # Test a nonexistent FQN with pytest.raises(ValueError, match="Failed to locate"): Settings(parsing=ParsingSettings(parse_pdf="nonexistent.module.function"))