Commit 7cad59d

yuxianq authored and dominicshanshan committed
[None][fix] fix CUDA graph config for test_llm_api_pytorch.py. (#6826)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent: 024887d


tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 8 additions & 4 deletions
@@ -561,7 +561,8 @@ def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
                  max_seq_len=8192,
                  pipeline_parallel_size=pp_size,
                  moe_expert_parallel_size=ep_size,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -584,7 +585,8 @@ def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
                  moe_expert_parallel_size=ep_size,
                  enable_chunked_prefill=True,
                  max_num_tokens=256,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -704,7 +706,8 @@ def test_fp8_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
                  moe_expert_parallel_size=ep_size,
                  enable_chunked_prefill=True,
                  max_num_tokens=256,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -726,7 +729,8 @@ def test_fp4_chunked_prefill(self, cuda_graph, tp_size, pp_size, ep_size):
                  max_seq_len=22000,
                  enable_chunked_prefill=True,
                  max_num_tokens=256,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
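
The diff above replaces the boolean use_cuda_graph flag with the cuda_graph_config argument, which takes a CudaGraphConfig object (or None to disable CUDA graph capture). A minimal sketch of the resulting pattern, assuming CudaGraphConfig is importable from tensorrt_llm.llmapi as in recent TensorRT-LLM releases; the build_llm helper and the model path are illustrative, not part of the commit:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig

def build_llm(model: str, cuda_graph: bool) -> LLM:
    # Map a boolean test parameter onto the config-object API:
    # a default CudaGraphConfig() enables CUDA graphs, None disables them.
    return LLM(
        model=model,
        cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
    )

# Illustrative usage; the model path is a placeholder.
llm = build_llm("/path/to/model", cuda_graph=True)

This mirrors how the tests pass the argument inline. Using a config object rather than a flag also leaves room for graph-specific options (for example, which batch sizes to capture) without reintroducing a separate boolean parameter.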
