Skip to content

Commit 51bded6

Browse files
authored
feat: Add show_progress parameter for independent progress bar control (google#227)
Progress bar visibility is now controlled independently of debug logging. Users can show or hide the progress bar without affecting debug output.
- Add `show_progress` parameter (defaults to True)
- Add parameterized tests for all flag combinations
1 parent 8c88b68 commit 51bded6

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

langextract/annotation.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -201,6 +201,7 @@ def annotate_documents(
201201
batch_length: int = 1,
202202
debug: bool = True,
203203
extraction_passes: int = 1,
204+
show_progress: bool = True,
204205
**kwargs,
205206
) -> Iterator[data.AnnotatedDocument]:
206207
"""Annotates a sequence of documents with NLP extractions.
@@ -223,6 +224,7 @@ def annotate_documents(
223224
standard single extraction.
224225
Values > 1 reprocess tokens multiple times, potentially increasing
225226
costs with the potential for a more thorough extraction.
227+
show_progress: Whether to show progress bar. Defaults to True.
226228
**kwargs: Additional arguments passed to LanguageModel.infer and Resolver.
227229
228230
Yields:
@@ -234,7 +236,13 @@ def annotate_documents(
234236

235237
if extraction_passes == 1:
236238
yield from self._annotate_documents_single_pass(
237-
documents, resolver, max_char_buffer, batch_length, debug, **kwargs
239+
documents,
240+
resolver,
241+
max_char_buffer,
242+
batch_length,
243+
debug,
244+
show_progress,
245+
**kwargs,
238246
)
239247
else:
240248
yield from self._annotate_documents_sequential_passes(
@@ -244,6 +252,7 @@ def annotate_documents(
244252
batch_length,
245253
debug,
246254
extraction_passes,
255+
show_progress,
247256
**kwargs,
248257
)
249258

@@ -254,6 +263,7 @@ def _annotate_documents_single_pass(
254263
max_char_buffer: int,
255264
batch_length: int,
256265
debug: bool,
266+
show_progress: bool = True,
257267
**kwargs,
258268
) -> Iterator[data.AnnotatedDocument]:
259269
"""Single-pass annotation logic (original implementation)."""
@@ -273,7 +283,7 @@ def _annotate_documents_single_pass(
273283
model_info = progress.get_model_info(self._language_model)
274284

275285
progress_bar = progress.create_extraction_progress_bar(
276-
batches, model_info=model_info, disable=not debug
286+
batches, model_info=model_info, disable=not show_progress
277287
)
278288

279289
chars_processed = 0
@@ -397,6 +407,7 @@ def _annotate_documents_sequential_passes(
397407
batch_length: int,
398408
debug: bool,
399409
extraction_passes: int,
410+
show_progress: bool = True,
400411
**kwargs,
401412
) -> Iterator[data.AnnotatedDocument]:
402413
"""Sequential extraction passes logic for improved recall."""
@@ -423,7 +434,8 @@ def _annotate_documents_sequential_passes(
423434
max_char_buffer,
424435
batch_length,
425436
debug=(debug and pass_num == 0),
426-
**kwargs, # Only show progress on first pass
437+
show_progress=show_progress if pass_num == 0 else False,
438+
**kwargs,
427439
):
428440
doc_id = annotated_doc.document_id
429441

@@ -472,6 +484,7 @@ def annotate_text(
472484
additional_context: str | None = None,
473485
debug: bool = True,
474486
extraction_passes: int = 1,
487+
show_progress: bool = True,
475488
**kwargs,
476489
) -> data.AnnotatedDocument:
477490
"""Annotates text with NLP extractions for text input.
@@ -488,6 +501,7 @@ def annotate_text(
488501
recall by finding additional entities. Defaults to 1, which performs
489502
standard single extraction. Values > 1 reprocess tokens multiple times,
490503
potentially increasing costs.
504+
show_progress: Whether to show progress bar. Defaults to True.
491505
**kwargs: Additional arguments for inference and resolver_lib.
492506
493507
Returns:
@@ -511,6 +525,7 @@ def annotate_text(
511525
batch_length,
512526
debug,
513527
extraction_passes,
528+
show_progress,
514529
**kwargs,
515530
)
516531
)

langextract/extraction.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ def extract(
5656
fetch_urls: bool = True,
5757
prompt_validation_level: pv.PromptValidationLevel = pv.PromptValidationLevel.WARNING,
5858
prompt_validation_strict: bool = False,
59+
show_progress: bool = True,
5960
) -> typing.Any:
6061
"""Extracts structured information from text.
6162
@@ -149,6 +150,7 @@ def extract(
149150
raises on failures. Defaults to WARNING.
150151
prompt_validation_strict: When True and prompt_validation_level is ERROR,
151152
raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to False.
153+
show_progress: Whether to show progress bar during extraction. Defaults to True.
152154
153155
Returns:
154156
An AnnotatedDocument with the extracted information when input is a
@@ -326,6 +328,7 @@ def extract(
326328
additional_context=additional_context,
327329
debug=debug,
328330
extraction_passes=extraction_passes,
331+
show_progress=show_progress,
329332
max_workers=max_workers,
330333
**alignment_kwargs,
331334
)
@@ -338,6 +341,7 @@ def extract(
338341
batch_length=batch_length,
339342
debug=debug,
340343
extraction_passes=extraction_passes,
344+
show_progress=show_progress,
341345
max_workers=max_workers,
342346
**alignment_kwargs,
343347
)

tests/init_test.py

Lines changed: 84 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from unittest import mock
1919

2020
from absl.testing import absltest
21+
from absl.testing import parameterized
2122

2223
from langextract import prompting
2324
import langextract as lx
@@ -26,7 +27,7 @@
2627
from langextract.providers import schemas
2728

2829

29-
class InitTest(absltest.TestCase):
30+
class InitTest(parameterized.TestCase):
3031
"""Test cases for the main package functions."""
3132

3233
@mock.patch.object(
@@ -454,6 +455,88 @@ def test_tokenizer_module_exports_via_compatibility_shim(self):
454455
f"lx.tokenizer.{name} not accessible via compatibility shim",
455456
)
456457

458+
@parameterized.named_parameters(
459+
dict(
460+
testcase_name="show_progress_true_debug_false",
461+
show_progress=True,
462+
debug=False,
463+
expected_progress_disabled=False,
464+
),
465+
dict(
466+
testcase_name="show_progress_false_debug_false",
467+
show_progress=False,
468+
debug=False,
469+
expected_progress_disabled=True,
470+
),
471+
dict(
472+
testcase_name="show_progress_true_debug_true",
473+
show_progress=True,
474+
debug=True,
475+
expected_progress_disabled=False,
476+
),
477+
dict(
478+
testcase_name="show_progress_false_debug_true",
479+
show_progress=False,
480+
debug=True,
481+
expected_progress_disabled=True,
482+
),
483+
)
484+
@mock.patch("langextract.progress.create_extraction_progress_bar")
485+
@mock.patch("langextract.extraction.factory.create_model")
486+
def test_show_progress_controls_progress_bar(
487+
self,
488+
mock_create_model,
489+
mock_progress,
490+
show_progress,
491+
debug,
492+
expected_progress_disabled,
493+
):
494+
"""Test that show_progress parameter controls progress bar visibility."""
495+
mock_model = mock.MagicMock()
496+
mock_model.infer.return_value = [
497+
[
498+
types.ScoredOutput(
499+
output='{"extractions": []}',
500+
score=0.9,
501+
)
502+
]
503+
]
504+
mock_model.requires_fence_output = False
505+
mock_create_model.return_value = mock_model
506+
507+
mock_progress_bar = mock.MagicMock()
508+
mock_progress_bar.__iter__ = mock.MagicMock(
509+
return_value=iter([mock.MagicMock()])
510+
)
511+
mock_progress.return_value = mock_progress_bar
512+
513+
mock_examples = [
514+
lx.data.ExampleData(
515+
text="Example text",
516+
extractions=[
517+
lx.data.Extraction(
518+
extraction_class="entity",
519+
extraction_text="example",
520+
),
521+
],
522+
)
523+
]
524+
525+
lx.extract(
526+
text_or_documents="test text",
527+
prompt_description="extract entities",
528+
examples=mock_examples,
529+
api_key="test_key",
530+
show_progress=show_progress,
531+
debug=debug,
532+
)
533+
534+
mock_progress.assert_called()
535+
call_args = mock_progress.call_args
536+
self.assertEqual(
537+
call_args.kwargs.get("disable", False), expected_progress_disabled
538+
)
539+
457540

458541
if __name__ == "__main__":
459542
absltest.main()

0 commit comments

Comments
 (0)