Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed issues related to duplicate code and documentation
  • Loading branch information
bharath03-a committed Sep 27, 2025
commit 0fbbc2e89d0fd65276b9eb1f900a49fae4b16c77
137 changes: 75 additions & 62 deletions langextract/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ def _process_batch_with_retry(
Args:
batch_prompts: List of prompts for the batch
batch: List of TextChunk objects corresponding to the prompts
retry_transient_errors: Whether to retry on transient errors
max_retries: Maximum number of retry attempts
retry_initial_delay: Initial delay before retry in seconds
retry_backoff_factor: Backoff multiplier for retries
retry_max_delay: Maximum delay between retries in seconds
**kwargs: Additional arguments passed to the language model

Yields:
Expand All @@ -237,8 +242,7 @@ def _process_batch_with_retry(
)
)

for result in batch_results:
yield result
yield from batch_results
return

except Exception as e:
Expand Down Expand Up @@ -279,8 +283,7 @@ def _process_batch_with_retry(
)
raise

for result in individual_results:
yield result
yield from individual_results

def _process_single_chunk_with_retry(
self,
Expand Down Expand Up @@ -308,68 +311,31 @@ def _process_single_chunk_with_retry(
Returns:
List containing a single ScoredOutput for this chunk
"""
last_exception = None
delay = retry_initial_delay

for attempt in range(max_retries + 1):
try:
batch_results = list(
self._language_model.infer(
batch_prompts=[prompt],
**kwargs,
)
)

if not batch_results:
raise exceptions.InferenceOutputError(
f"No results returned for chunk in document {chunk.document_id}"
)

return batch_results[0]

except Exception as e:
last_exception = e

if not retry_transient_errors or not retry_utils.is_transient_error(e):
logging.debug(
"Not retrying chunk processing: retry_disabled=%s,"
" is_transient=%s, error=%s",
not retry_transient_errors,
retry_utils.is_transient_error(e),
str(e),
)
raise

if attempt >= max_retries:
logging.error(
"Chunk processing failed after %d retries: %s",
max_retries,
str(e),
# Use the retry decorator with custom parameters
@retry_utils.retry_chunk_processing(
max_retries=max_retries,
initial_delay=retry_initial_delay,
backoff_factor=retry_backoff_factor,
max_delay=retry_max_delay,
enabled=retry_transient_errors,
)
def _process_chunk():
batch_results = list(
self._language_model.infer(
batch_prompts=[prompt],
**kwargs,
)
raise

current_delay = min(delay, retry_max_delay)

import random

jitter_amount = current_delay * 0.1 * random.random()
current_delay += jitter_amount
)

logging.warning(
"Chunk processing failed on attempt %d/%d due to transient error:"
" %s. Retrying in %.2f seconds...",
attempt + 1,
max_retries + 1,
str(e),
current_delay,
if not batch_results:
raise exceptions.InferenceOutputError(
f"No results returned for chunk in document {chunk.document_id}"
)

time.sleep(current_delay)
delay = min(delay * retry_backoff_factor, retry_max_delay)
return batch_results[0]

if last_exception:
raise last_exception
raise RuntimeError("Chunk retry logic failed unexpectedly")
return _process_chunk()

def annotate_documents(
self,
Expand Down Expand Up @@ -408,6 +374,11 @@ def annotate_documents(
Values > 1 reprocess tokens multiple times, potentially increasing
costs with the potential for a more thorough extraction.
show_progress: Whether to show progress bar. Defaults to True.
retry_transient_errors: Whether to retry on transient errors. Defaults to True.
max_retries: Maximum number of retry attempts. Defaults to 3.
retry_initial_delay: Initial delay before retry in seconds. Defaults to 1.0.
retry_backoff_factor: Backoff multiplier for retries. Defaults to 2.0.
retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
**kwargs: Additional arguments passed to LanguageModel.infer and Resolver.

Yields:
Expand Down Expand Up @@ -466,7 +437,25 @@ def _annotate_documents_single_pass(
retry_max_delay: float = 60.0,
**kwargs,
) -> Iterator[data.AnnotatedDocument]:
"""Single-pass annotation logic (original implementation)."""
"""Single-pass annotation logic (original implementation).

Args:
documents: Iterable of documents to annotate
resolver: Resolver for processing inference results
max_char_buffer: Maximum character buffer for chunking
batch_length: Number of chunks to process in each batch
debug: Whether to enable debug logging
show_progress: Whether to show progress bar
retry_transient_errors: Whether to retry on transient errors
max_retries: Maximum number of retry attempts
retry_initial_delay: Initial delay before retry in seconds
retry_backoff_factor: Backoff multiplier for retries
retry_max_delay: Maximum delay between retries in seconds
**kwargs: Additional arguments passed to language model

Yields:
AnnotatedDocument objects with extracted data
"""

logging.info("Starting document annotation.")
doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
Expand Down Expand Up @@ -622,7 +611,26 @@ def _annotate_documents_sequential_passes(
retry_max_delay: float = 60.0,
**kwargs,
) -> Iterator[data.AnnotatedDocument]:
"""Sequential extraction passes logic for improved recall."""
"""Sequential extraction passes logic for improved recall.

Args:
documents: Iterable of documents to annotate
resolver: Resolver for processing inference results
max_char_buffer: Maximum character buffer for chunking
batch_length: Number of chunks to process in each batch
debug: Whether to enable debug logging
extraction_passes: Number of extraction passes to perform
show_progress: Whether to show progress bar
retry_transient_errors: Whether to retry on transient errors
max_retries: Maximum number of retry attempts
retry_initial_delay: Initial delay before retry in seconds
retry_backoff_factor: Backoff multiplier for retries
retry_max_delay: Maximum delay between retries in seconds
**kwargs: Additional arguments passed to language model

Yields:
AnnotatedDocument objects with merged extracted data
"""

logging.info(
"Starting sequential extraction passes for improved recall with %d"
Expand Down Expand Up @@ -722,6 +730,11 @@ def annotate_text(
standard single extraction. Values > 1 reprocess tokens multiple times,
potentially increasing costs.
show_progress: Whether to show progress bar. Defaults to True.
retry_transient_errors: Whether to retry on transient errors. Defaults to True.
max_retries: Maximum number of retry attempts. Defaults to 3.
retry_initial_delay: Initial delay before retry in seconds. Defaults to 1.0.
retry_backoff_factor: Backoff multiplier for retries. Defaults to 2.0.
retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
**kwargs: Additional arguments for inference and resolver_lib.

Returns:
Expand Down
94 changes: 61 additions & 33 deletions langextract/retry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
from __future__ import annotations

import functools
import random
import time
from typing import Any, Callable, TypeVar

from absl import logging

from langextract.core import exceptions

T = TypeVar("T")

# Transient error patterns that should trigger retries
Expand Down Expand Up @@ -79,10 +78,60 @@ def is_transient_error(error: Exception) -> bool:

# Check for transient exception types
is_transient_type = error_type in TRANSIENT_EXCEPTION_TYPES

return is_transient_pattern or is_transient_type


def execute_retry_with_backoff(
    attempt: int,
    max_retries: int,
    delay: float,
    max_delay: float,
    backoff_factor: float,
    error: Exception,
    operation_name: str = "operation",
) -> float:
  """Handle one failed attempt: give up, or sleep and return the next delay.

  If the retry budget is exhausted the failure is logged and `error` is
  re-raised. Otherwise this function logs a warning, blocks the calling
  thread for the (jittered) delay, and returns the delay the caller should
  pass in on the next failed attempt.

  Args:
    attempt: Current attempt number (0-based).
    max_retries: Maximum number of retries.
    delay: Current delay value in seconds, before capping and jitter.
    max_delay: Upper bound applied to the delay before jitter is added.
    backoff_factor: Factor the delay is multiplied by for the next attempt.
    error: The exception that occurred; used in log messages and re-raised
      when no retries remain.
    operation_name: Name of the operation, used as the log message prefix.

  Returns:
    The delay for the next iteration: `delay * backoff_factor` capped at
    `max_delay`, never negative.

  Raises:
    Exception: `error` itself, when `attempt >= max_retries`.
  """
  if attempt >= max_retries:
    logging.error(
        "%s failed after %d retries: %s",
        operation_name,
        max_retries,
        str(error),
    )
    raise error

  # Clamp at zero: time.sleep() rejects negative values with ValueError,
  # which would replace the error that is actually being retried.
  current_delay = max(0.0, min(delay, max_delay))

  # Up to 10% random jitter is added on top of the capped delay, so the
  # actual sleep may exceed max_delay by as much as 10%.
  current_delay += current_delay * 0.1 * random.random()

  logging.warning(
      "%s failed on attempt %d/%d due to transient error: %s. "
      "Retrying in %.2f seconds...",
      operation_name,
      attempt + 1,
      max_retries + 1,
      str(error),
      current_delay,
  )

  time.sleep(current_delay)
  return max(0.0, min(delay * backoff_factor, max_delay))


def retry_on_transient_errors(
max_retries: int = 3,
initial_delay: float = 1.0,
Expand Down Expand Up @@ -136,8 +185,6 @@ def wrapper(*args: Any, **kwargs: Any) -> T:

# Add jitter to prevent thundering herd.
if jitter:
import random

jitter_amount = current_delay * 0.1 * random.random()
current_delay += jitter_amount

Expand Down Expand Up @@ -210,36 +257,17 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
)
raise

# If we've exhausted retries, raise the exception
if attempt >= max_retries:
logging.error(
"Chunk processing failed after %d retries: %s",
max_retries,
str(e),
)
raise

# Calculate delay with exponential backoff
current_delay = min(delay, max_delay)

# Add jitter to prevent thundering herd
import random

jitter_amount = current_delay * 0.1 * random.random()
current_delay += jitter_amount

logging.warning(
"Chunk processing failed on attempt %d/%d due to transient error:"
" %s. Retrying in %.2f seconds...",
attempt + 1,
max_retries + 1,
str(e),
current_delay,
# Execute retry logic with backoff
delay = execute_retry_with_backoff(
attempt=attempt,
max_retries=max_retries,
delay=delay,
max_delay=max_delay,
backoff_factor=backoff_factor,
error=e,
operation_name="Chunk processing",
)

time.sleep(current_delay)
delay = min(delay * backoff_factor, max_delay)

# This should never be reached, but just in case
if last_exception:
raise last_exception
Expand Down
47 changes: 21 additions & 26 deletions tests/annotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,39 +1252,34 @@ def test_retry_parameters_passed_to_single_chunk_processing(self):
documents = [data.Document(text="Test document", document_id="test_doc")]

mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
self.mock_language_model.infer.return_value = iter([[mock_result]])

with mock.patch.object(
self.annotator,
"_process_single_chunk_with_retry",
return_value=[mock_result],
) as mock_single_chunk:
with mock.patch.object(
self.annotator,
"_process_batch_with_retry",
side_effect=Exception("Batch processing failed"),
):
try:
list(
self.annotator.annotate_documents(
documents=documents,
retry_transient_errors=True,
max_retries=3,
retry_initial_delay=1.0,
retry_backoff_factor=2.0,
retry_max_delay=60.0,
)
self.mock_language_model.infer.side_effect = Exception(
"503 The model is overloaded"
)

results = list(
self.annotator.annotate_documents(
documents=documents,
retry_transient_errors=True,
max_retries=3,
retry_initial_delay=1.0,
retry_backoff_factor=2.0,
retry_max_delay=60.0,
)
except Exception:
pass

if mock_single_chunk.called:
call_kwargs = mock_single_chunk.call_args[1]
self.assertEqual(call_kwargs["retry_transient_errors"], True)
self.assertEqual(call_kwargs["max_retries"], 3)
self.assertEqual(call_kwargs["retry_initial_delay"], 1.0)
self.assertEqual(call_kwargs["retry_backoff_factor"], 2.0)
self.assertEqual(call_kwargs["retry_max_delay"], 60.0)
)

self.assertTrue(mock_single_chunk.called)
call_kwargs = mock_single_chunk.call_args[1]
self.assertEqual(call_kwargs["retry_transient_errors"], True)
self.assertEqual(call_kwargs["max_retries"], 3)
self.assertEqual(call_kwargs["retry_initial_delay"], 1.0)
self.assertEqual(call_kwargs["retry_backoff_factor"], 2.0)
self.assertEqual(call_kwargs["retry_max_delay"], 60.0)


if __name__ == "__main__":
Expand Down
Loading