Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
encode_id,
format_bibtex,
get_citation_ids,
get_parenthetical_substrings,
maybe_get_date,
string_to_bytes,
)
Expand Down Expand Up @@ -377,19 +378,18 @@ def populate_formatted_answers_and_bib_from_raw_answer(
}
name_bib = {}

# https://regex101.com/r/h2Ca20/1
for parenthetical in re.findall(r"\(([^)]*)\)", formatted_without_references):

for parenthetical in get_parenthetical_substrings(formatted_without_references):
# now we replace eligible parentheticals with the deduped names
deduped_names = {
id_to_name_map.get(key, "") for key in get_citation_ids(parenthetical)
}

# replace the parenthetical with the deduped names
# while we preserve order and deduplicate
deduped_names: dict[str, None] = dict.fromkeys(
id_to_name_map.get(key) # type: ignore[misc]
for key in get_citation_ids(parenthetical)
if id_to_name_map.get(key)
)
if deduped_names:
formatted_without_references = formatted_without_references.replace(
parenthetical,
f"{', '.join(deduped_names)}",
f"({', '.join(deduped_names)})",
)
for deduped_name in deduped_names:
if (
Expand Down
27 changes: 25 additions & 2 deletions src/paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,31 @@ def extract_score(text: str) -> int:
raise ValueError(f"Failed to extract score from text {text!r}.")


def get_citation_ids(text: str) -> set[str]:
    """Collect the distinct ``pqac-`` citation IDs that occur in ``text``.

    An ID is the literal prefix ``pqac-`` followed by exactly eight ASCII
    letters or digits, delimited by word boundaries on both sides.

    Args:
        text: The string to scan.

    Returns:
        The set of matched IDs; empty when nothing matches.
    """
    return {hit.group() for hit in re.finditer(r"\bpqac-[a-zA-Z0-9]{8}\b", text)}
def get_parenthetical_substrings(text: str) -> list[str]:
    """
    Find every balanced parenthesized substring, including nested ones.

    Each result includes its surrounding parentheses. Results are emitted in
    the order their closing parenthesis appears, so an inner parenthetical
    always precedes the parenthetical that encloses it. A ")" with no pending
    "(" is ignored, as is a "(" that is never closed.

    Args:
        text: The input string to analyze.

    Returns:
        A list of parenthetical substrings.
    """
    found: list[str] = []
    # Stack of positions of "(" characters still waiting for their ")".
    pending_opens: list[int] = []
    for position, character in enumerate(text):
        if character == "(":
            pending_opens.append(position)
            continue
        if character != ")" or not pending_opens:
            # Either an ordinary character or a stray ")" with nothing to close.
            continue
        # The most recent unmatched "(" pairs with this ")".
        opened_at = pending_opens.pop()
        found.append(text[opened_at : position + 1])
    return found


def get_citation_ids(text: str) -> list[str]:
    """Return the ``pqac-`` citation IDs found in ``text``, without repeats.

    An ID is the literal prefix ``pqac-`` followed by exactly eight ASCII
    letters or digits, delimited by word boundaries on both sides.

    Args:
        text: The string to scan.

    Returns:
        The matched IDs in order of first appearance, each listed once.
    """
    # A dict keeps insertion order, so it doubles as an ordered set here.
    first_seen: dict[str, None] = {}
    for hit in re.finditer(r"\bpqac-[a-zA-Z0-9]{8}\b", text):
        first_seen.setdefault(hit.group(), None)
    return list(first_seen)


def extract_doi(reference: str) -> str:
Expand Down
119 changes: 118 additions & 1 deletion tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pathlib import Path
from typing import cast
from unittest.mock import MagicMock, call, patch
from uuid import UUID
from uuid import UUID, uuid4

import httpx
import litellm
Expand Down Expand Up @@ -2499,3 +2499,120 @@ def test_str_bytes_conversions(value: bytes) -> None:

# Validate that encoded string is valid base64
assert base64.b64decode(encoded_string) == value


# Sample answer text mixing citation parentheticals, a non-citation
# parenthetical, and an empty "()".
# The pieces are joined by implicit string-literal concatenation, so the text
# reads "...quoteTEST AND..." with no space between the first two pieces.
# NOTE(review): nothing in the visible part of this file references this
# constant -- confirm whether it is still needed.
tricky_test = (
    "simple (pqac-a020507f) quote"
    "TEST AND (easy OR mistaken OR not_context)"
    " and another AND not context yes, end. (pqac-a020507f, pqac-4552861e)"
    " duplicates ()"
)


@pytest.mark.parametrize(
    ("raw_text", "cleaned_text"),
    [
        # One known ID in one parenthetical.
        ("simple (pqac-a020507f) quote", "simple (text1) quote"),
        # The same ID appearing in two separate parentheticals.
        (
            "compound (pqac-a020507f) quote (pqac-a020507f, pqac-4552861e)",
            "compound (text1) quote (text1, text2)",
        ),
        # Two identical parentheticals.
        (
            "already replaced (pqac-a020507f, pqac-4552861e) quote (pqac-a020507f, pqac-4552861e)",
            "already replaced (text1, text2) quote (text1, text2)",
        ),
        # Same IDs in opposite orders: each parenthetical keeps its own order.
        (
            "back (pqac-4552861e, pqac-a020507f) and forth (pqac-a020507f, pqac-4552861e)",
            "back (text2, text1) and forth (text1, text2)",
        ),
        # A parenthetical with no citation IDs is expected to stay untouched.
        (
            "distractor (easy OR mistaken OR not_context) quote (pqac-a020507f, pqac-4552861e)",
            "distractor (easy OR mistaken OR not_context) quote (text1, text2)",
        ),
        # pqac-4552861f (trailing "f") matches none of the contexts below,
        # so only text1 is expected.
        ("duplicates (pqac-a020507f, pqac-4552861f)", "duplicates (text1)"),
        (
            "compound duplicates (pqac-a020507f, pqac-4552861f, pqac-4552861e)",
            "compound duplicates (text1, text2)",
        ),
        (
            "triples (pqac-a020507f, pqac-4552861e, pqac-d0a08f29) with single (pqac-d0a08f29)",
            "triples (text1, text2, text3) with single (text3)",
        ),
        # pqac-1234567e matches none of the contexts below and is expected
        # to be dropped.
        (
            "triples hallucinated (pqac-a020507f, pqac-1234567e, pqac-d0a08f29) with single (pqac-d0a08f29)",
            "triples hallucinated (text1, text3) with single (text3)",
        ),
        # IDs joined by "and" are expected to come out comma-separated.
        ("with and (pqac-a020507f and pqac-4552861e)", "with and (text1, text2)"),
        # Non-ID text sharing a parenthetical with an ID is expected to vanish.
        ("extra stuff (pqac-4552861e, odd-unseen thing)", "extra stuff (text2)"),
        # Only the inner parenthetical of a nested pair is rewritten.
        (
            "nested (parenthetical(pqac-4552861e, odd-unseen thing))",
            "nested (parenthetical(text2))",
        ),
        # Two different IDs whose texts share the name "text3" collapse to one.
        (
            "dupe text name (pqac-d0a08f29, pqac-387deef6) xyz (pqac-387deef6)",
            "dupe text name (text3) xyz (text3)",
        ),
    ],
)
def test_pqa_context_id_parsing(raw_text: str, cleaned_text: str) -> None:
    """Check that citation IDs in a raw answer become text names in the answer.

    Builds a session whose ``raw_answer`` is ``raw_text`` and whose contexts
    cover four IDs (two of them sharing the text name "text3"), runs
    ``populate_formatted_answers_and_bib_from_raw_answer``, and compares
    ``session.answer`` against ``cleaned_text``.
    """
    # Columns: context id, context/text body, text name, docname, citation, dockey.
    # The last two rows deliberately describe the same text name and document.
    context_rows = [
        ("pqac-a020507f", "blah blah", "text1", "test_doc1", "Test Doc1, 2025", "key1"),
        ("pqac-4552861e", "quote", "text2", "test_doc2", "Test Doc2, 2025", "key2"),
        ("pqac-d0a08f29", "quote", "text3", "test_doc3", "Test Doc3, 2025", "key3"),
        ("pqac-387deef6", "quote", "text3", "test_doc3", "Test Doc3, 2025", "key3"),
    ]
    contexts = [
        Context(
            id=context_id,
            context=body,
            text=Text(
                name=text_name,
                text=body,
                doc=Doc(docname=docname, citation=citation, dockey=dockey),
            ),
        )
        for context_id, body, text_name, docname, citation, dockey in context_rows
    ]
    session = PQASession(
        id=uuid4(),
        question="test",
        raw_answer=raw_text,
        contexts=contexts,
    )

    session.populate_formatted_answers_and_bib_from_raw_answer()

    assert session.answer == cleaned_text