Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ Here's a short demo of how to do this:
```python
from paperqa.clients import DocMetadataClient, ALL_CLIENTS

client = DocMetadataClient(clients=ALL_CLIENTS)
client = DocMetadataClient(metadata_clients=ALL_CLIENTS)
details = await client.query(title="Augmenting language models with chemistry tools")

print(details.formatted_citation)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"aiohttp>=3.10.6", # TODO: remove in favor of httpx, pin for aiohttp.ClientConnectionResetError
"anyio",
"fhaviary[llm]>=0.20", # For Environment.get_id
"fhlmi>=0.25.4", # For LLM reasoning
Expand Down
44 changes: 25 additions & 19 deletions src/paperqa/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Awaitable, Collection, Coroutine, Sequence
from typing import Any, cast

import aiohttp
import httpx
from lmi.utils import gather_with_concurrency
from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -60,11 +60,11 @@ def provider_queries(
return [p.query(query) for p in self.providers]

def processor_queries(
self, doc_details: DocDetails, session: aiohttp.ClientSession
self, doc_details: DocDetails, client: httpx.AsyncClient
) -> list[Coroutine[Any, Any, DocDetails]]:
"""Set up process coroutines for each contained metadata post-processor."""
return [
p.process(copy.copy(doc_details), session=session) for p in self.processors
p.process(copy.copy(doc_details), client=client) for p in self.processors
]

def __repr__(self) -> str:
Expand All @@ -76,29 +76,29 @@ def __repr__(self) -> str:
class DocMetadataClient:
def __init__(
self,
session: aiohttp.ClientSession | None = None,
clients: (
http_client: httpx.AsyncClient | None = None,
metadata_clients: (
Collection[type[MetadataPostProcessor | MetadataProvider]]
| Sequence[Collection[type[MetadataPostProcessor | MetadataProvider]]]
) = DEFAULT_CLIENTS,
) -> None:
"""Metadata client for querying multiple metadata providers and processors.

Args:
session: outer scope aiohttp session to allow for connection pooling
clients: list of MetadataProvider and MetadataPostProcessor classes to query;
http_client: Async HTTP client to allow for connection pooling.
metadata_clients: list of MetadataProvider and MetadataPostProcessor classes to query;
if nested, will query in order looking for termination criteria after each.
Will terminate early if either DocDetails.is_hydration_needed is False OR if
all requested fields are present in the DocDetails object.
"""
self._session = session
self._http_client = http_client
self.tasks: list[DocMetadataTask] = []

# first see if we are nested; i.e. we want order
if isinstance(clients, Sequence) and all(
isinstance(sub_clients, Collection) for sub_clients in clients
if isinstance(metadata_clients, Sequence) and all(
isinstance(sub_clients, Collection) for sub_clients in metadata_clients
):
for sub_clients in clients:
for sub_clients in metadata_clients:
self.tasks.append(
DocMetadataTask(
providers=[
Expand All @@ -119,18 +119,20 @@ def __init__(
)
)
# otherwise, we are a flat collection
if not self.tasks and all(not isinstance(c, Collection) for c in clients):
if not self.tasks and all(
not isinstance(c, Collection) for c in metadata_clients
):
self.tasks.append(
DocMetadataTask(
providers=[
c if isinstance(c, MetadataProvider) else c() # type: ignore[redundant-expr]
for c in clients
for c in metadata_clients
if (isinstance(c, type) and issubclass(c, MetadataProvider))
or isinstance(c, MetadataProvider)
],
processors=[
c if isinstance(c, MetadataPostProcessor) else c() # type: ignore[redundant-expr]
for c in clients
for c in metadata_clients
if (
isinstance(c, type) and issubclass(c, MetadataPostProcessor)
)
Expand All @@ -144,9 +146,13 @@ def __init__(

async def query(self, **kwargs) -> DocDetails | None:

session = aiohttp.ClientSession() if self._session is None else self._session
client = (
httpx.AsyncClient(timeout=10.0)
if self._http_client is None
else self._http_client
)

query_args = kwargs if "session" in kwargs else kwargs | {"session": session}
query_args = kwargs if "client" in kwargs else kwargs | {"client": client}

all_doc_details: DocDetails | None = None

Expand Down Expand Up @@ -175,7 +181,7 @@ async def query(self, **kwargs) -> DocDetails | None:
sum(
await gather_with_concurrency(
len(task.processors),
task.processor_queries(doc_details, session),
task.processor_queries(doc_details, client),
)
)
or None
Expand All @@ -194,8 +200,8 @@ async def query(self, **kwargs) -> DocDetails | None:
)
break

if self._session is None:
await session.close()
if self._http_client is None:
await client.aclose()

return all_doc_details

Expand Down
6 changes: 3 additions & 3 deletions src/paperqa/clients/client_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Collection
from typing import Any, Generic, TypeVar

import aiohttp
import httpx
from pydantic import (
BaseModel,
ConfigDict,
Expand All @@ -27,7 +27,7 @@
class ClientQuery(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

session: aiohttp.ClientSession
client: httpx.AsyncClient


class TitleAuthorQuery(ClientQuery):
Expand Down Expand Up @@ -121,7 +121,7 @@ async def query(self, query: dict) -> DocDetails | None:
f" {self.__class__.__name__}."
)
# we're suppressing this error to not fail on 403 or 500 errors from providers
except aiohttp.ClientError:
except httpx.RequestError:
logger.warning(
"Client error for"
f" {client_query.doi if isinstance(client_query, DOIQuery) else client_query.title} in"
Expand Down
74 changes: 33 additions & 41 deletions src/paperqa/clients/crossref.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any
from urllib.parse import quote

import aiohttp
import httpx
from anyio import open_file
from lmi.utils import CROSSREF_KEY_HEADER
from tenacity import (
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_crossref_mailto() -> str:
)
async def doi_to_bibtex(
doi: str,
session: aiohttp.ClientSession,
client: httpx.AsyncClient,
missing_replacements: Mapping[str, str | list[str]] | None = None,
) -> str:
"""Get a bibtex entry from a DOI via Crossref, replacing the key if possible.
Expand All @@ -126,15 +126,15 @@ async def doi_to_bibtex(
missing_replacements = {}
FORBIDDEN_KEY_CHARACTERS = {"_", " ", "-", "/"}
# get DOI via crossref
async with session.get(
r = await client.get(
f"https://api.crossref.org/works/{quote(doi, safe='')}/transform/application/x-bibtex",
headers=crossref_headers(),
) as r:
if not r.ok:
raise DOINotFoundError(
f"Per HTTP status code {r.status}, could not resolve DOI {doi}."
)
data = await r.text()
)
if not r.is_success:
raise DOINotFoundError(
f"Per HTTP status code {r.status_code}, could not resolve DOI {doi}."
)
data = r.text
# must make new key
key = data.split("{")[1].split(",")[0]
new_key = remove_substrings(key, FORBIDDEN_KEY_CHARACTERS)
Expand Down Expand Up @@ -165,7 +165,7 @@ async def doi_to_bibtex(

async def parse_crossref_to_doc_details(
message: dict[str, Any],
session: aiohttp.ClientSession,
client: httpx.AsyncClient,
query_bibtex: bool = False,
) -> DocDetails:

Expand All @@ -185,7 +185,7 @@ async def parse_crossref_to_doc_details(
# since we now create the bibtex from scratch
if query_bibtex:
bibtex = await doi_to_bibtex(
message["DOI"], session, missing_replacements=fallback_data
message["DOI"], client, missing_replacements=fallback_data
)
# track the origin of the bibtex entry for debugging
bibtex_source = BibTeXSource.CROSSREF.value
Expand Down Expand Up @@ -250,7 +250,7 @@ async def parse_crossref_to_doc_details(
stop=stop_after_attempt(5),
)
async def get_doc_details_from_crossref( # noqa: PLR0912
session: aiohttp.ClientSession,
client: httpx.AsyncClient,
doi: str | None = None,
authors: list[str] | None = None,
title: str | None = None,
Expand Down Expand Up @@ -302,24 +302,24 @@ async def get_doc_details_from_crossref( # noqa: PLR0912
}
)

async with session.get(
response = await client.get(
url,
params=params,
headers=crossref_headers(),
timeout=aiohttp.ClientTimeout(CROSSREF_API_REQUEST_TIMEOUT),
) as response:
try:
response.raise_for_status()
except aiohttp.ClientResponseError as exc:
raise DOINotFoundError(f"Could not find paper given {inputs_msg}.") from exc
try:
response_data = await response.json()
except json.JSONDecodeError as exc:
# JSONDecodeError: Crossref didn't answer with JSON, perhaps HTML
raise DOINotFoundError( # Use DOINotFoundError so we fall back to Google Scholar
f"Crossref API did not return JSON for {inputs_msg}, instead it"
f" responded with text: {await response.text()}"
) from exc
timeout=CROSSREF_API_REQUEST_TIMEOUT,
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as exc:
raise DOINotFoundError(f"Could not find paper given {inputs_msg}.") from exc
try:
response_data = response.json()
except json.JSONDecodeError as exc:
# JSONDecodeError: Crossref didn't answer with JSON, perhaps HTML
raise DOINotFoundError( # Use DOINotFoundError so we fall back to Google Scholar
f"Crossref API did not return JSON for {inputs_msg}, instead it"
f" responded with text: {response.text}"
) from exc
if response_data["status"] == "failed":
raise DOINotFoundError(
f"Crossref API returned a failed status for {inputs_msg}."
Expand All @@ -345,7 +345,7 @@ async def get_doc_details_from_crossref( # noqa: PLR0912
if doi is not None and message["DOI"] != doi:
raise DOINotFoundError(f"DOI ({inputs_msg}) not found in Crossref")

return await parse_crossref_to_doc_details(message, session, query_bibtex)
return await parse_crossref_to_doc_details(message, client, query_bibtex)


@retry(
Expand All @@ -364,24 +364,16 @@ async def download_retracted_dataset(
"""
url = f"https://api.labs.crossref.org/data/retractionwatch?{get_crossref_mailto()}"

async with (
aiohttp.ClientSession() as session,
session.get(
url,
timeout=aiohttp.ClientTimeout(total=300),
) as response,
):
async with httpx.AsyncClient(timeout=300) as client:
response = await client.get(url)
response.raise_for_status()

logger.info(
f"Retraction data was not cashed. Downloading retraction data from {url}..."
)

async with await open_file(str(retraction_data_path), "wb") as f:
while True:
chunk = await response.content.read(1024)
if not chunk:
break
async for chunk in response.aiter_bytes(chunk_size=1024):
await f.write(chunk)

if os.path.getsize(str(retraction_data_path)) == 0:
Expand All @@ -392,12 +384,12 @@ class CrossrefProvider(DOIOrTitleBasedProvider):
async def _query(self, query: TitleAuthorQuery | DOIQuery) -> DocDetails | None:
if isinstance(query, DOIQuery):
return await get_doc_details_from_crossref(
doi=query.doi, session=query.session, fields=query.fields
doi=query.doi, client=query.client, fields=query.fields
)
return await get_doc_details_from_crossref(
title=query.title,
authors=query.authors,
session=query.session,
client=query.client,
title_similarity_threshold=query.title_similarity_threshold,
fields=query.fields,
)
6 changes: 3 additions & 3 deletions src/paperqa/clients/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable

import aiohttp
import httpx


class DOINotFoundError(Exception):
Expand All @@ -11,11 +11,11 @@ def __init__(self, message="DOI not found") -> None:

def make_flaky_ssl_error_predicate(host: str) -> Callable[[BaseException], bool]:
def predicate(exc: BaseException) -> bool:
# Seen with both Semantic Scholar and Crossref:
# > aiohttp.client_exceptions.ClientConnectorError:
# > Cannot connect to host api.host.org:443 ssl:default
# > [nodename nor servname provided, or not known]
# SEE: https://github.com/aio-libs/aiohttp/blob/v3.10.5/aiohttp/client_exceptions.py#L193-L196
return isinstance(exc, aiohttp.ClientConnectorError) and exc.host == host
# Then we migrated to httpx
return isinstance(exc, httpx.ConnectError) and exc.request.url.host == host

return predicate
Loading