Commit 0ca4843

Fix: Pass project parameter correctly to storage.Client in Gemini Batch API (#286)
Fixes #285: Batch API failing with "Required parameter: project" error.

The project and location parameters from language_model_params were not being passed through to _submit_file(), causing storage.Client to be created without a project context.

Changes:
- Updated _submit_file() to accept project and location parameters
- Modified _process_batch() to pass these parameters through
- Added a test to verify the project parameter is correctly passed
- Updated tests to include the required batch config parameters
1 parent 0c1af87 commit 0ca4843

File tree

2 files changed: +131, -11 lines
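
For context on the failure mode: google-cloud-storage resolves the project either from the constructor argument or from the ambient environment. The snippet below is a minimal illustration of why forwarding an explicit project matters; make_storage_client and the project value are illustrative, not code from this repository.

# Illustrative sketch, not repository code: the helper name and project
# value are assumptions used to show the failure mode behind #285.
from google.cloud import storage


def make_storage_client(project: str | None = None) -> storage.Client:
  # With project=None, storage.Client falls back to ambient configuration
  # (e.g. GOOGLE_CLOUD_PROJECT or gcloud defaults); when nothing is
  # configured, the GCS calls fail with a missing-project error such as
  # the "Required parameter: project" message reported in #285.
  return storage.Client(project=project)


# Passing the project through explicitly, as this commit does for the
# batch path, removes the dependency on ambient configuration.
client = make_storage_client(project="my-gcp-project")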

langextract/providers/gemini_batch.py

Lines changed: 16 additions & 2 deletions

@@ -300,6 +300,8 @@ def _submit_file(
     requests: Sequence[dict],
     display: str,
     retention_days: int | None,
+    project: str | None = None,
+    location: str | None = None,
 ) -> genai.types.BatchJob:
   """Submit a file-based batch job to Vertex AI using GCS storage.

@@ -317,6 +319,10 @@ def _submit_file(
     display: Display name for the batch job, used for identification and
       as part of the GCS blob name.
     retention_days: Days to keep GCS data. If set, applies lifecycle rule.
+    project: Optional GCP project ID. If not provided, will attempt to
+      determine from client or environment.
+    location: Optional GCP region/location. If not provided, will attempt to
+      determine from client or use default.

   Returns:
     BatchJob object that can be polled for completion status.
@@ -336,7 +342,7 @@ def _submit_file(
       line = {"key": f"{_KEY_IDX}{idx}", "request": req}
       f.write(json.dumps(line, ensure_ascii=False) + "\n")

-  project, location = _get_project_location(client)
+  project, location = _get_project_location(client, project, location)
   bucket_name = _get_bucket_name(project, location)
   blob_name = f"batch-input/{display}-{uuid.uuid4().hex}.jsonl"

@@ -807,7 +813,15 @@ def _process_batch(
       )
       for p in batch_prompts
   ]
-  job = _submit_file(client, model_id, requests, display, cfg.retention_days)
+  job = _submit_file(
+      client,
+      model_id,
+      requests,
+      display,
+      cfg.retention_days,
+      project,
+      location,
+  )
   if cfg.on_job_create:
     try:
       cfg.on_job_create(job)
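
The body of _get_project_location is not part of this diff, so how it merges the new explicit arguments with values inferred from the client is not shown here. The following is a hedged sketch of one plausible shape, consistent with the updated call site; it is not the repository's implementation, and the fallback chain, environment variable, and default region are assumptions.

import os


def _get_project_location(
    client,
    project: str | None = None,
    location: str | None = None,
) -> tuple[str, str]:
  # Assumed behavior: explicit arguments win, then client attributes,
  # then the environment; fail early if no project can be determined.
  project = (
      project
      or getattr(client, "project", None)
      or os.environ.get("GOOGLE_CLOUD_PROJECT")
  )
  location = location or getattr(client, "location", None) or "us-central1"
  if not project:
    raise ValueError("Required parameter: project")
  return project, location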

tests/test_gemini_batch_api.py

Lines changed: 115 additions & 9 deletions

@@ -102,7 +102,13 @@ def test_batch_routing_vertex(self, mock_client_cls):
         vertexai=True,
         project="test-project",
         location=gb._DEFAULT_LOCATION,
-        batch={"enabled": True, "threshold": 2, "poll_interval": 1},
+        batch={
+            "enabled": True,
+            "threshold": 2,
+            "poll_interval": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
     prompts = ["p1", "p2"]
     outs = list(model.infer(prompts))
@@ -156,7 +162,12 @@ def test_realtime_when_below_threshold(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 10},
+        batch={
+            "enabled": True,
+            "threshold": 10,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
     outs = list(model.infer(["hello"]))

@@ -195,7 +206,12 @@ def test_batch_with_schema(self, mock_client_cls):
         project="p",
         location="l",
         gemini_schema=mock_schema,
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )

     # Mock _submit_file to verify the request payload contains the schema.
@@ -207,7 +223,7 @@ def test_batch_with_schema(self, mock_client_cls):
     self.assertLen(outs, 1)
     self.assertEqual(outs[0][0].output, '{"name":"test"}')

-    # Verify _submit_file was called with correct arguments.
+    # Verify _submit_file was called with project and location parameters.
     mock_submit.assert_called_with(
         mock_client,
         "gemini-2.5-flash",
@@ -222,6 +238,9 @@ def test_batch_with_schema(self, mock_client_cls):
             },
         }],
         mock.ANY,  # Display name contains timestamp/random.
+        None,  # retention_days
+        "p",  # project
+        "l",  # location
     )

     self.assertEqual(model.gemini_schema.schema_dict, mock_schema.schema_dict)
@@ -238,7 +257,12 @@ def test_batch_error_handling(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )

     with self.assertRaisesRegex(Exception, "Gemini Batch API error"):
@@ -273,7 +297,12 @@ def test_file_based_ordering(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )

     results = list(model.infer(prompts))
@@ -354,6 +383,8 @@ def list_blobs_side_effect(prefix=None):
             "enabled": True,
             "threshold": 1,
             "max_prompts_per_job": max_prompts_per_job,
+            "enable_caching": False,
+            "retention_days": None,
         },
     )

@@ -386,7 +417,12 @@ def test_batch_item_error(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )

     with self.assertRaisesRegex(Exception, "Batch item error"):
@@ -420,7 +456,12 @@ def test_empty_prompts_fast_path(self, mock_client_cls):
         prompts=[],
         schema_dict=None,
         gen_config={},
-        cfg=gb.BatchConfig(enabled=True, poll_interval=1),
+        cfg=gb.BatchConfig(
+            enabled=True,
+            poll_interval=1,
+            enable_caching=False,
+            retention_days=None,
+        ),
     )
     self.assertEqual(outs, [])

@@ -443,7 +484,13 @@ def test_file_pad_to_expected_count(self, mock_client_cls):
     mock_client.batches.create.return_value = job
     mock_client.batches.get.return_value = job

-    cfg = gb.BatchConfig(enabled=True, threshold=1, poll_interval=1)
+    cfg = gb.BatchConfig(
+        enabled=True,
+        threshold=1,
+        poll_interval=1,
+        enable_caching=False,
+        retention_days=None,
+    )
     outs = gb.infer_batch(
         client=mock_client,
         model_id="m",
@@ -481,6 +528,7 @@ def test_cache_hit_skips_inference(self, mock_client_cls):
         enabled=True,
         threshold=1,
         enable_caching=True,
+        retention_days=None,
     )

     outs = gb.infer_batch(
@@ -541,6 +589,7 @@ def get_blob(name):
         enabled=True,
         threshold=1,
         enable_caching=True,
+        retention_days=None,
     )

     outs = gb.infer_batch(
@@ -566,6 +615,63 @@ def get_blob(name):
         upload_calls, "Should have uploaded new_response to cache"
     )

+  @mock.patch.object(genai, "Client", autospec=True)
+  @mock.patch.dict("os.environ", {}, clear=True)
+  def test_project_passed_to_storage_client(self, mock_client_cls):
+    """Test that project parameter is passed to storage.Client constructor."""
+    mock_client = mock_client_cls.return_value
+    mock_client.vertexai = True
+    if hasattr(mock_client, "project"):
+      del mock_client.project
+
+    self.mock_storage_client.create_bucket.return_value = self.mock_bucket
+
+    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)
+    output_blob.name = f"output{gb._EXT_JSONL}"
+    output_blob.open.return_value.__enter__.return_value = io.StringIO(
+        _create_batch_response(0, {"result": "ok"})
+    )
+    self.mock_bucket.list_blobs.return_value = [output_blob]
+
+    mock_client.batches.create.return_value = create_mock_batch_job()
+    mock_client.batches.get.return_value = create_mock_batch_job()
+
+    # Create model with specific project and location
+    test_project = "test-project-123"
+    test_location = "us-central1"
+
+    model = gemini.GeminiLanguageModel(
+        model_id="gemini-2.5-flash",
+        vertexai=True,
+        project=test_project,
+        location=test_location,
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "poll_interval": 0.1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
+    )
+
+    list(model.infer(["test prompt"]))
+
+    # Verify storage.Client was called with the correct project parameter.
+    storage_calls = self.mock_storage_cls.call_args_list
+
+    project_calls = [
+        call
+        for call in storage_calls
+        if call.kwargs.get("project") == test_project
+    ]
+
+    self.assertGreaterEqual(
+        len(project_calls),
+        1,
+        f"storage.Client should be called with project={test_project}, "
+        f"but was called with: {[call.kwargs for call in storage_calls]}",
+    )
+
   def test_cache_hashing_stability(self):
     """Test that hash is stable for same inputs."""
     cache = gb.GCSBatchCache("b")
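
With the fix, a caller that supplies an explicit project and location when constructing the provider (as the new test does) no longer depends on ambient gcloud configuration for the batch path. A usage sketch, assuming the constructor parameters exercised in the tests above; the import path and the project/location values are placeholders, not verified against the released package.

# Usage sketch under the assumptions stated above.
from langextract.providers import gemini

model = gemini.GeminiLanguageModel(
    model_id="gemini-2.5-flash",
    vertexai=True,
    project="my-gcp-project",  # forwarded through _submit_file to storage.Client
    location="us-central1",
    batch={
        "enabled": True,
        "threshold": 1,
        "poll_interval": 1,
        "enable_caching": False,
        "retention_days": None,
    },
)
results = list(model.infer(["some prompt"]))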
