@@ -9,6 +9,9 @@
 import aiohttp
 import pytest
 import yaml
+from defs.conftest import skip_no_hopper
+from defs.disaggregated.test_disaggregated_single_gpu import \
+    model_path as get_model_path
 from defs.trt_test_alternative import popen
 from transformers import AutoTokenizer
 
@@ -19,8 +22,6 @@
                                        KvCacheAwareServerState, ServerRole,
                                        block_key_hasher)
 
-MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-
 
 def get_ctx_gen_server_urls_from_cfg(config_file: str):
     with open(config_file, 'r') as file:
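
Note: this hunk drops the module-level `MODEL_NAME` constant in favor of a per-tester `model_name` argument, and imports `model_path` (aliased to `get_model_path`) to resolve a model name to a local checkout. The resolver's internals are not shown in this diff; a minimal sketch of what such a helper could look like, with `LLM_MODELS_ROOT` as an assumed environment variable:

```python
import os

# Hypothetical resolver; the actual get_model_path lives in
# test_disaggregated_single_gpu and may differ.
def model_path(model_name: str) -> str:
    models_root = os.environ.get("LLM_MODELS_ROOT", "/scratch/llm-models")
    return os.path.join(models_root, model_name)
```
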
@@ -184,15 +185,17 @@ def __init__(self,
                  ctx_servers: List[str],
                  gen_servers: List[str],
                  req_timeout_secs: int = 180,
-                 server_start_timeout_secs: int = 180):
+                 server_start_timeout_secs: int = 180,
+                 model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
         super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                          server_start_timeout_secs)
+        self.model_name = model_name
 
     async def multi_round_request(self, session: aiohttp.ClientSession,
                                   init_prompt: str, max_rounds: int,
                                   threshold: float):
         request = {
-            "model": MODEL_NAME,
+            "model": self.model_name,
             "prompt": init_prompt,
             "max_tokens": 10,
             "ignore_eos": True,
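
For context, a request dict like the one above is typically POSTed to each worker's OpenAI-compatible completions endpoint. A sketch under that assumption (`send_completion` is illustrative, not a helper from this file):

```python
import aiohttp

async def send_completion(session: aiohttp.ClientSession, url: str,
                          request: dict) -> dict:
    # Illustrative only: post the payload to an OpenAI-compatible endpoint
    # and return the parsed JSON body.
    async with session.post(f"{url}/v1/completions", json=request) as resp:
        resp.raise_for_status()
        return await resp.json()
```
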
@@ -235,10 +238,15 @@ def __init__(self,
                  ctx_servers: List[str],
                  gen_servers: List[str],
                  req_timeout_secs: int = 180,
-                 server_start_timeout_secs: int = 240):
+                 server_start_timeout_secs: int = 240,
+                 model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                 model_path: Optional[str] = None):
         super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                          server_start_timeout_secs)
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+        if model_path is None:
+            model_path = get_model_path(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model_name = model_name
         self.kv_cache_block_maps: dict[str, KvCacheAwareServerState] = {}
         self.kv_cache_event_maps: dict[str, list[dict]] = {}
         for ctx_server in ctx_servers:
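
Loading the tokenizer from a resolved local path rather than the hub id keeps CI runs offline; `AutoTokenizer.from_pretrained` accepts either form:

```python
from transformers import AutoTokenizer

# Hub id: may trigger a network download on first use.
tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# Local directory (hypothetical path): no network access needed.
tok = AutoTokenizer.from_pretrained("/scratch/llm-models/TinyLlama-1.1B-Chat-v1.0")
```
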
@@ -266,7 +274,7 @@ async def multi_round_request(self,
                               max_rounds: int,
                               check_match_count: bool = True):
         request = {
-            "model": MODEL_NAME,
+            "model": self.model_name,
             "prompt": init_prompt,
             "max_tokens": 64,
             "ignore_eos": True,
@@ -347,21 +355,26 @@ def __init__(self,
                  ctx_servers: List[str],
                  gen_servers: List[str],
                  req_timeout_secs: int = 180,
-                 server_start_timeout_secs: int = 180):
+                 server_start_timeout_secs: int = 180,
+                 model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                 tokens_per_block: int = 32):
         super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                          server_start_timeout_secs)
         self.ctx_router = KvCacheAwareRouter(server_role=ServerRole.CONTEXT,
-                                             servers=ctx_servers)
+                                             servers=ctx_servers,
+                                             tokens_per_block=tokens_per_block)
         self.gen_router = KvCacheAwareRouter(server_role=ServerRole.GENERATION,
-                                             servers=gen_servers)
+                                             servers=gen_servers,
+                                             tokens_per_block=tokens_per_block)
+        self.model_name = model_name
 
     async def multi_round_request(self,
                                   session: aiohttp.ClientSession,
                                   init_prompt: str,
                                   max_rounds: int = 8,
                                   check_server_match: bool = True):
         request = {
-            "model": MODEL_NAME,
+            "model": self.model_name,
             "prompt": init_prompt,
             "max_tokens": 64,
             "ignore_eos": True,
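
Threading `tokens_per_block` through to `KvCacheAwareRouter` matters because KV cache reuse is tracked per fixed-size token block: the router can only match a prompt prefix against a server's cached blocks if it chunks and hashes tokens with the same block size the server uses (the DeepSeek test below passes 64 instead of the TinyLlama default of 32). A self-contained sketch of chained block hashing, illustrative only and not the implementation of the imported `block_key_hasher`:

```python
import hashlib
from typing import List

def chained_block_hashes(token_ids: List[int],
                         tokens_per_block: int) -> List[int]:
    """Hash each full block of token ids, chaining in the parent block's
    hash so equal prefixes yield equal key sequences (illustrative only)."""
    hashes: List[int] = []
    parent = 0
    for start in range(0, len(token_ids) - tokens_per_block + 1,
                       tokens_per_block):
        block = token_ids[start:start + tokens_per_block]
        digest = hashlib.sha256(parent.to_bytes(8, "little") +
                                str(block).encode()).digest()
        parent = int.from_bytes(digest[:8], "little")
        hashes.append(parent)
    return hashes

# With mismatched block sizes the key sequences diverge immediately,
# so router-side matches against server-reported blocks would be zero.
assert chained_block_hashes(list(range(128)), 32) != \
       chained_block_hashes(list(range(128)), 64)
```
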
@@ -373,7 +386,7 @@ async def multi_round_request(self,
         gen_match = 0
         for i in range(max_rounds):
             openai_request = CompletionRequest(
-                model=MODEL_NAME,
+                model=self.model_name,
                 prompt=request["prompt"],
                 disaggregated_params=DisaggregatedParams(
                     request_type="context_only"))
@@ -425,7 +438,7 @@ async def test_eviction(self):
         async with await self.new_session() as session:
             # send a dummy request for initialization
             dummy_request = {
-                "model": MODEL_NAME,
+                "model": self.model_name,
                 "prompt": [3] * 200,
                 "max_tokens": 1,
                 "ignore_eos": True,
@@ -447,7 +460,7 @@ async def test_eviction(self):
             logger.info(f"Block pool size: {block_pool_size}")
 
             # the dummy request can be reused
-            openai_request = CompletionRequest(model=MODEL_NAME,
+            openai_request = CompletionRequest(model=self.model_name,
                                                prompt=dummy_request["prompt"])
             server, info = await self.gen_router.get_next_server(openai_request)
             first_match = info["matches"][0]
@@ -503,8 +516,7 @@ def load_default_prompts(disaggregated_example_root: str):
 @contextlib.contextmanager
 def background_workers(llm_venv, config_file: str, num_ranks: int = None):
     cwd = llm_venv.get_working_directory()
-
-    with open(os.path.join(cwd, 'output_workers.log'), 'w') as log_file:
+    with open(os.path.join(cwd, 'output_workers.log'), 'w+') as log_file:
         workers_proc, ctx_servers, gen_servers = run_disaggregated_workers(
             config_file=config_file,
             stdout=log_file,
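
The switch from `'w'` to `'w+'` looks minor but changes capability: both truncate the file on open, but `'w+'` also permits reading it back through the same handle, presumably so a failing test can surface worker output. A sketch of that usage (the read-back itself is an assumption, not shown in this diff):

```python
with open('output_workers.log', 'w+') as log_file:
    ...  # workers write here via stdout/stderr redirection
    log_file.flush()
    log_file.seek(0)        # rewind to the start before reading
    tail = log_file.read()  # e.g. attach worker output to a failure report
```
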
@@ -537,6 +549,30 @@ def test_workers_conditional_disaggregation(disaggregated_test_root,
     asyncio.run(tester.test_multi_round_request(prompts))
 
 
+@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_workers_conditional_disaggregation_deepseek_v3_lite_bf16(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        deepseek_v3_model_root):
+    config_file = os.path.join(
+        disaggregated_test_root,
+        'test_configs/disagg_config_cache_reuse_deepseek_v3.yaml')
+    model_root = f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16"
+    src_dst_dict = {
+        deepseek_v3_model_root: model_root,
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    with background_workers(llm_venv, config_file,
+                            2) as (ctx_servers, gen_servers):
+        tester = ConditionalWorkerTester(ctx_servers, gen_servers)
+        prompts = load_default_prompts(disaggregated_example_root)
+        asyncio.run(tester.test_multi_round_request(prompts))
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_workers_kv_cache_events(disaggregated_test_root,
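
The new test above repeats a symlink setup that recurs verbatim in the router test added further below: expose the shared model checkout under the venv working directory so the relative model paths in the YAML configs resolve. If this pattern spreads further, a small helper would remove the duplication; a hypothetical sketch, not part of this diff:

```python
import os

def ensure_model_symlink(src: str, dst: str) -> None:
    """Link the shared model checkout (src) into the working tree (dst),
    creating parent directories as needed; no-op if the link exists."""
    if not os.path.islink(dst):
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        os.symlink(src, dst, target_is_directory=True)
```
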
@@ -570,6 +606,35 @@ def test_workers_kv_cache_aware_router(disaggregated_test_root,
     asyncio.run(tester.test_multi_round_request(prompts, 16, 4))
 
 
+@skip_no_hopper
+@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_workers_kv_cache_aware_router_deepseek_v3_lite_bf16(
+        disaggregated_test_root, disaggregated_example_root, llm_venv,
+        deepseek_v3_model_root):
+    config_file = os.path.join(
+        disaggregated_test_root,
+        'test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml')
+    model_root = f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16"
+    src_dst_dict = {
+        deepseek_v3_model_root: model_root,
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    with background_workers(llm_venv, config_file,
+                            4) as (ctx_servers, gen_servers):
+        os.chdir(llm_venv.get_working_directory())
+        tester = KvCacheAwareRouterTester(ctx_servers,
+                                          gen_servers,
+                                          model_name="DeepSeek-V3-Lite/bf16",
+                                          tokens_per_block=64)
+        prompts = load_default_prompts(disaggregated_example_root)
+        asyncio.run(tester.test_multi_round_request(prompts, 8, 4))
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_workers_kv_cache_aware_router_eviction(disaggregated_test_root,