@@ -102,7 +102,13 @@ def test_batch_routing_vertex(self, mock_client_cls):
         vertexai=True,
         project="test-project",
         location=gb._DEFAULT_LOCATION,
-        batch={"enabled": True, "threshold": 2, "poll_interval": 1},
+        batch={
+            "enabled": True,
+            "threshold": 2,
+            "poll_interval": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
     prompts = ["p1", "p2"]
     outs = list(model.infer(prompts))
@@ -156,7 +162,12 @@ def test_realtime_when_below_threshold(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 10},
+        batch={
+            "enabled": True,
+            "threshold": 10,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
     outs = list(model.infer(["hello"]))
 
@@ -195,7 +206,12 @@ def test_batch_with_schema(self, mock_client_cls):
         project="p",
         location="l",
         gemini_schema=mock_schema,
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
 
     # Mock _submit_file to verify the request payload contains the schema.
@@ -207,7 +223,7 @@ def test_batch_with_schema(self, mock_client_cls):
     self.assertLen(outs, 1)
     self.assertEqual(outs[0][0].output, '{"name":"test"}')
 
-    # Verify _submit_file was called with correct arguments.
+    # Verify _submit_file was called with project and location parameters.
     mock_submit.assert_called_with(
         mock_client,
         "gemini-2.5-flash",
@@ -222,6 +238,9 @@ def test_batch_with_schema(self, mock_client_cls):
             },
         }],
         mock.ANY,  # Display name contains timestamp/random.
+        None,  # retention_days
+        "p",  # project
+        "l",  # location
     )
 
     self.assertEqual(model.gemini_schema.schema_dict, mock_schema.schema_dict)
@@ -238,7 +257,12 @@ def test_batch_error_handling(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
 
     with self.assertRaisesRegex(Exception, "Gemini Batch API error"):
@@ -273,7 +297,12 @@ def test_file_based_ordering(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
 
     results = list(model.infer(prompts))
@@ -354,6 +383,8 @@ def list_blobs_side_effect(prefix=None):
             "enabled": True,
             "threshold": 1,
             "max_prompts_per_job": max_prompts_per_job,
+            "enable_caching": False,
+            "retention_days": None,
         },
     )
 
@@ -386,7 +417,12 @@ def test_batch_item_error(self, mock_client_cls):
         vertexai=True,
         project="p",
         location="l",
-        batch={"enabled": True, "threshold": 1},
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
     )
 
     with self.assertRaisesRegex(Exception, "Batch item error"):
@@ -420,7 +456,12 @@ def test_empty_prompts_fast_path(self, mock_client_cls):
         prompts=[],
         schema_dict=None,
         gen_config={},
-        cfg=gb.BatchConfig(enabled=True, poll_interval=1),
+        cfg=gb.BatchConfig(
+            enabled=True,
+            poll_interval=1,
+            enable_caching=False,
+            retention_days=None,
+        ),
     )
     self.assertEqual(outs, [])
 
@@ -443,7 +484,13 @@ def test_file_pad_to_expected_count(self, mock_client_cls):
     mock_client.batches.create.return_value = job
     mock_client.batches.get.return_value = job
 
-    cfg = gb.BatchConfig(enabled=True, threshold=1, poll_interval=1)
+    cfg = gb.BatchConfig(
+        enabled=True,
+        threshold=1,
+        poll_interval=1,
+        enable_caching=False,
+        retention_days=None,
+    )
     outs = gb.infer_batch(
         client=mock_client,
         model_id="m",
@@ -481,6 +528,7 @@ def test_cache_hit_skips_inference(self, mock_client_cls):
         enabled=True,
         threshold=1,
         enable_caching=True,
+        retention_days=None,
     )
 
     outs = gb.infer_batch(
@@ -541,6 +589,7 @@ def get_blob(name):
         enabled=True,
         threshold=1,
         enable_caching=True,
+        retention_days=None,
     )
 
     outs = gb.infer_batch(
@@ -566,6 +615,63 @@ def get_blob(name):
         upload_calls, "Should have uploaded new_response to cache"
     )
 
+  @mock.patch.object(genai, "Client", autospec=True)
+  @mock.patch.dict("os.environ", {}, clear=True)
+  def test_project_passed_to_storage_client(self, mock_client_cls):
+    """Test that project parameter is passed to storage.Client constructor."""
+    mock_client = mock_client_cls.return_value
+    mock_client.vertexai = True
+    if hasattr(mock_client, "project"):
+      del mock_client.project
+
+    self.mock_storage_client.create_bucket.return_value = self.mock_bucket
+
+    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)
+    output_blob.name = f"output{gb._EXT_JSONL}"
+    output_blob.open.return_value.__enter__.return_value = io.StringIO(
+        _create_batch_response(0, {"result": "ok"})
+    )
+    self.mock_bucket.list_blobs.return_value = [output_blob]
+
+    mock_client.batches.create.return_value = create_mock_batch_job()
+    mock_client.batches.get.return_value = create_mock_batch_job()
+
+    # Create a model with a specific project and location.
+    test_project = "test-project-123"
+    test_location = "us-central1"
+
+    model = gemini.GeminiLanguageModel(
+        model_id="gemini-2.5-flash",
+        vertexai=True,
+        project=test_project,
+        location=test_location,
+        batch={
+            "enabled": True,
+            "threshold": 1,
+            "poll_interval": 0.1,
+            "enable_caching": False,
+            "retention_days": None,
+        },
+    )
+
+    list(model.infer(["test prompt"]))
+
+    # Verify storage.Client was called with the correct project parameter.
+    storage_calls = self.mock_storage_cls.call_args_list
+
+    project_calls = [
+        call
+        for call in storage_calls
+        if call.kwargs.get("project") == test_project
+    ]
+
+    self.assertGreaterEqual(
+        len(project_calls),
+        1,
+        f"storage.Client should be called with project={test_project}, "
+        f"but was called with: {[call.kwargs for call in storage_calls]}",
+    )
+
   def test_cache_hashing_stability(self):
     """Test that hash is stable for same inputs."""
     cache = gb.GCSBatchCache("b")