15 changes: 13 additions & 2 deletions tensorrt_llm/mapping.py
@@ -62,8 +62,19 @@ def __init__(
if moe_cluster_size == -1:
moe_cluster_size = 1

cp_type = CpType.ULYSSES if cp_config is None else cp_config.get(
"cp_type", CpType.ULYSSES)
# Set default cp_type to ULYSSES.
cp_type = CpType.ULYSSES

# Convert cp_type to CpType enum if it is a string.
if cp_config is not None:
if "cp_type" in cp_config and isinstance(cp_config["cp_type"], str):
try:
cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()]
except KeyError:
raise ValueError(f"Invalid cp_type: {cp_config['cp_type']}. " \
f"Must be one of: {', '.join([t.name for t in CpType])}")
cp_type = cp_config.get("cp_type", CpType.ULYSSES)

moe_world_size = tp_size if cp_type == CpType.ULYSSES else tp_size * cp_size

if moe_tp_size == -1 and moe_ep_size == -1:
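
As a side note for readers outside the TensorRT-LLM tree, here is a self-contained sketch of the normalization this hunk introduces. The `CpType` enum below is a stand-in with an assumed member set; the real one lives in `tensorrt_llm.mapping`:

```python
from enum import Enum


class CpType(Enum):
    # Stand-in for the real enum in tensorrt_llm; the member set is an assumption.
    ULYSSES = 1
    HELIX = 2


def resolve_cp_type(cp_config):
    """Mirror the mapping.py logic: default to ULYSSES, accept strings case-insensitively."""
    cp_type = CpType.ULYSSES
    if cp_config is not None:
        if "cp_type" in cp_config and isinstance(cp_config["cp_type"], str):
            try:
                # Replace the string in place with the corresponding enum member.
                cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()]
            except KeyError:
                raise ValueError(
                    f"Invalid cp_type: {cp_config['cp_type']}. "
                    f"Must be one of: {', '.join(t.name for t in CpType)}")
        cp_type = cp_config.get("cp_type", CpType.ULYSSES)
    return cp_type


print(resolve_cp_type({"cp_type": "helix"}))  # CpType.HELIX
print(resolve_cp_type(None))                  # CpType.ULYSSES
```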
5 changes: 5 additions & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -70,6 +70,11 @@ deepseek-ai/DeepSeek-V3-Lite:
kv_cache_quant_algo: FP8
spec_dec_algo: MTP
accuracy: 64.14
# https://nvbugs/5637012: Currently, BS>1 has accuracy issues with helix for GSM8K.
# BS=1 reaches the expected accuracy but is too slow for CI testing, so this relaxed
# accuracy spec is used while we investigate the issue.
- extra_acc_spec: helix_with_bs8
accuracy: 50.0
deepseek-ai/DeepSeek-R1:
- quant_algo: NVFP4
accuracy: 95.42
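
A hypothetical illustration of how an entry keyed by extra_acc_spec, such as the helix_with_bs8 line added above, could be picked out of this reference file. The function name and fallback rule are assumptions, not the actual accuracy-harness code (requires PyYAML):

```python
import yaml

references = yaml.safe_load("""
deepseek-ai/DeepSeek-V3-Lite:
  - quant_algo: FP8
    accuracy: 64.14
  - extra_acc_spec: helix_with_bs8
    accuracy: 50.0
""")


def reference_accuracy(model, extra_acc_spec=None):
    # Pick the entry whose extra_acc_spec matches; fall back to the first entry.
    entries = references[model]
    for entry in entries:
        if entry.get("extra_acc_spec") == extra_acc_spec:
            return entry["accuracy"]
    return entries[0]["accuracy"]


assert reference_accuracy("deepseek-ai/DeepSeek-V3-Lite",
                          extra_acc_spec="helix_with_bs8") == 50.0
```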
85 changes: 72 additions & 13 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -161,17 +161,17 @@ def launch_disaggregated_llm(
"--backend",
"pytorch",
]
gen_tp, gen_pp = gen_server_config.get(
"tensor_parallel_size",
tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
1)
ctx_tp, ctx_pp = ctx_server_config.get(
"tensor_parallel_size",
tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
1)

ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp
gen_tp, gen_pp, gen_cp = gen_server_config.get(
"tensor_parallel_size", tensor_parallel_size), gen_server_config.get(
"pipeline_parallel_size",
1), gen_server_config.get("context_parallel_size", 1)
ctx_tp, ctx_pp, ctx_cp = ctx_server_config.get(
"tensor_parallel_size", tensor_parallel_size), ctx_server_config.get(
"pipeline_parallel_size",
1), ctx_server_config.get("context_parallel_size", 1)

ctx_total_gpus = ctx_tp * ctx_pp * ctx_cp
gen_total_gpus = gen_tp * gen_pp * gen_cp

ctx_urls = disaggregated_server_config["context_servers"]["urls"]
gen_urls = disaggregated_server_config["generation_servers"]["urls"]
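
For orientation, a small standalone sketch (not the helper itself) of the GPU accounting once context parallelism is counted: each server now contributes tp * pp * cp ranks, and the per-server cp value is also forwarded to the workers as a --cp_size flag further down. This is why the helix test added below is marked skip_less_device(4).

```python
def parallel_sizes(server_config, default_tp=1):
    """Mirror the test helper's defaults: tp falls back to the suite-level value, pp/cp to 1."""
    tp = server_config.get("tensor_parallel_size", default_tp)
    pp = server_config.get("pipeline_parallel_size", 1)
    cp = server_config.get("context_parallel_size", 1)
    return tp, pp, cp


# GPU accounting for the helix test added below:
ctx_tp, ctx_pp, ctx_cp = parallel_sizes({"tensor_parallel_size": 2})   # (2, 1, 1)
gen_tp, gen_pp, gen_cp = parallel_sizes({"context_parallel_size": 2})  # (1, 1, 2)
print(ctx_tp * ctx_pp * ctx_cp + gen_tp * gen_pp * gen_cp)             # 4 GPUs in total
```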
@@ -194,7 +194,7 @@ def launch_disaggregated_llm(
ctx_server_args = ctx_args + [
"--port",
str(port), "--extra_llm_api_options", ctx_server_config_path,
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}", f"--cp_size={ctx_cp}"
]
if "max_num_tokens" in ctx_server_config:
ctx_server_args.append(
@@ -215,7 +215,7 @@ def launch_disaggregated_llm(
gen_server_args = gen_args + [
"--port",
str(port), "--extra_llm_api_options", gen_server_config_path,
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}", f"--cp_size={gen_cp}"
]
if "max_num_tokens" in gen_server_config:
gen_server_args.append(
@@ -814,6 +814,65 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
def test_auto_dtype_with_helix(self):
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
"enable_partial_reuse": False,
"tokens_per_block": 32,
}
ctx_server_config = {
"pipeline_parallel_size": 1,
"tensor_parallel_size": 2,
"context_parallel_size": 1,
"max_batch_size": 8,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"enable_chunked_prefill": False,
"cuda_graph_config": None,
"cache_transceiver_config": {
"backend": "UCX"
},
}
gen_server_config = {
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"context_parallel_size": 2,
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32
},
"max_batch_size": 8,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"enable_chunked_prefill": False,
"cuda_graph_config": None,
"cache_transceiver_config": {
"backend": "UCX"
},
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm, extra_acc_spec="helix_with_bs8")
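
Tying this test back to the mapping.py change: the string "HELIX" in cp_config is upper-cased and converted to the enum when the Mapping is built. A rough sketch follows, assuming the constructor arguments shown here; the exact signature and import path should be checked against tensorrt_llm/mapping.py.

```python
from tensorrt_llm.mapping import CpType, Mapping  # assumed import path

cp_config = {"cp_type": "HELIX", "tokens_per_block": 32}
# Generation server in this test: tp=1, pp=1, cp=2 -> world_size=2.
mapping = Mapping(world_size=2, tp_size=1, pp_size=1, cp_size=2, cp_config=cp_config)
# The string is normalized in place to the enum member.
assert cp_config["cp_type"] is CpType.HELIX
```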

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(60000)
@parametrize_with_ids("mtp_nextn", [0, 2])