14 changes: 14 additions & 0 deletions tensorrt_llm/bench/benchmark/low_latency.py
@@ -13,6 +13,7 @@

 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm._tensorrt_engine import LLM
+from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
 from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
 from tensorrt_llm.bench.benchmark.utils.general import generate_warmup_dataset
 from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
@@ -298,7 +299,20 @@ def latency_command(
         kwargs["pytorch_backend_config"].enable_iter_perf_stats = True

     if runtime_config.backend == 'pytorch':
+        if kwargs.pop("extended_runtime_perf_knob_config", None):
+            logger.warning(
+                "Ignore extended_runtime_perf_knob_config for pytorch backend."
+            )
         llm = PyTorchLLM(**kwargs)
+    elif runtime_config.backend == "_autodeploy":
+        if kwargs.pop("extended_runtime_perf_knob_config", None):
+            logger.warning(
+                "Ignore extended_runtime_perf_knob_config for _autodeploy backend."
+            )
+        kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
+        kwargs.pop("pipeline_parallel_size", None)
+
+        llm = AutoDeployLLM(**kwargs)
     else:
         llm = LLM(**kwargs)

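The new `_autodeploy` branch adapts the benchmark kwargs to the AutoDeploy LLM constructor: the engine-only `extended_runtime_perf_knob_config` is popped with a warning, `tensor_parallel_size` is renamed to `world_size`, and `pipeline_parallel_size` is discarded. Below is a minimal sketch of that remapping in isolation; the `adapt_kwargs_for_autodeploy` helper and the example kwargs are illustrative only and not part of this PR or of the trtllm-bench code.

import logging

logger = logging.getLogger("trtllm-bench-example")


def adapt_kwargs_for_autodeploy(kwargs: dict) -> dict:
    # Hypothetical helper mirroring the _autodeploy branch in the diff above.
    # The engine-only perf knob is dropped with a warning if it was set.
    if kwargs.pop("extended_runtime_perf_knob_config", None):
        logger.warning(
            "Ignore extended_runtime_perf_knob_config for _autodeploy backend.")
    # The diff maps tensor_parallel_size onto world_size and drops
    # pipeline_parallel_size before constructing AutoDeployLLM.
    kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
    kwargs.pop("pipeline_parallel_size", None)
    return kwargs


# Example values are made up for illustration.
example = {
    "model": "example/model",
    "tensor_parallel_size": 2,
    "pipeline_parallel_size": 1,
}
print(adapt_kwargs_for_autodeploy(example))
# {'model': 'example/model', 'world_size': 2}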
38 changes: 19 additions & 19 deletions tensorrt_llm/bench/benchmark/throughput.py
@@ -255,25 +255,25 @@ def throughput_command(
     logger.info("Preparing to run throughput benchmark...")
     # Parameters from CLI
     # Model, experiment, and engine params
-    dataset_path: Path = params.pop("dataset")
-    eos_id: int = params.pop("eos_id")
+    dataset_path: Path = params.get("dataset")
+    eos_id: int = params.get("eos_id")
     warmup: int = params.get("warmup")
-    num_requests: int = params.pop("num_requests")
-    max_seq_len: int = params.pop("max_seq_len")
+    num_requests: int = params.get("num_requests")
+    max_seq_len: int = params.get("max_seq_len")
     model: str = bench_env.model
     checkpoint_path: Path = bench_env.checkpoint_path or bench_env.model
-    engine_dir: Path = params.pop("engine_dir")
-    concurrency: int = params.pop("concurrency")
+    engine_dir: Path = params.get("engine_dir")
+    concurrency: int = params.get("concurrency")
     backend: str = params.get("backend")
-    modality: str = params.pop("modality")
-    max_input_len: int = params.pop("max_input_len")
+    modality: str = params.get("modality")
+    max_input_len: int = params.get("max_input_len")
     model_type = get_model_config(model, checkpoint_path).model_type

     # Reporting options
-    report_json: Path = params.pop("report_json")
-    output_json: Path = params.pop("output_json")
-    request_json: Path = params.pop("request_json")
-    iteration_log: Path = params.pop("iteration_log")
+    report_json: Path = params.get("report_json")
+    output_json: Path = params.get("output_json")
+    request_json: Path = params.get("request_json")
+    iteration_log: Path = params.get("iteration_log")
     iteration_writer = IterationWriter(iteration_log)

     # Runtime kwargs and option tracking.
@@ -340,15 +340,15 @@ def throughput_command(
     engine_tokens = exec_settings["settings_config"]["max_num_tokens"]

     # Runtime Options
-    runtime_max_bs = params.pop("max_batch_size")
-    runtime_max_tokens = params.pop("max_num_tokens")
+    runtime_max_bs = params.get("max_batch_size")
+    runtime_max_tokens = params.get("max_num_tokens")
     runtime_max_bs = runtime_max_bs or engine_bs
     runtime_max_tokens = runtime_max_tokens or engine_tokens
-    kv_cache_percent = params.pop("kv_cache_free_gpu_mem_fraction")
-    beam_width = params.pop("beam_width")
-    streaming: bool = params.pop("streaming")
-    enable_chunked_context: bool = params.pop("enable_chunked_context")
-    scheduler_policy: str = params.pop("scheduler_policy")
+    kv_cache_percent = params.get("kv_cache_free_gpu_mem_fraction")
+    beam_width = params.get("beam_width")
+    streaming: bool = params.get("streaming")
+    enable_chunked_context: bool = params.get("enable_chunked_context")
+    scheduler_policy: str = params.get("scheduler_policy")

     # Update configuration with runtime options
     exec_settings["settings_config"]["kv_cache_percent"] = kv_cache_percent
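Throughout throughput_command, params.pop(...) is replaced with params.get(...). The behavioral difference is that get reads a value while leaving the key in the params dict, so the same options remain visible to any later reads of params, whereas pop removes the key after the first access. A small standalone illustration of the two dict methods (the dict below is invented for the example, not the benchmark's actual params):

# Illustrative only: dict.pop vs dict.get semantics, not benchmark code.
params = {"max_batch_size": 256, "beam_width": 1}

max_bs = params.pop("max_batch_size")   # returns 256 and removes the key
assert max_bs == 256 and "max_batch_size" not in params

beam_width = params.get("beam_width")   # returns 1 and keeps the key
assert beam_width == 1 and params["beam_width"] == 1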