Merged
Changes from all commits
34 commits
abc9bb9
Update to pull LLM from a central location.
FrankD412 Jul 29, 2025
509c0c8
Fix missing iteration log set up.
FrankD412 Jul 29, 2025
fbfe59a
Add missing argument.
FrankD412 Jul 29, 2025
12755e0
Update to fix TRT property removal.
FrankD412 Jul 29, 2025
4530b65
Add docstrings.
FrankD412 Jul 29, 2025
28cc594
Add missing param to tput benchmark.
FrankD412 Jul 29, 2025
29eac89
Refactor code to use centralized options.
FrankD412 Aug 4, 2025
35668cd
Updates to re-organize code.
FrankD412 Aug 4, 2025
5c5cd3d
Update iteration_writer property.
FrankD412 Aug 4, 2025
62003bc
Fix model_config property.
FrankD412 Aug 4, 2025
400539e
Fix throughput tokenizer.
FrankD412 Aug 4, 2025
882a22f
Fix tokenizer field.
FrankD412 Aug 4, 2025
f85c521
Fix incorrect initialize_tokenizer arg.
FrankD412 Aug 5, 2025
78b0e55
Add aliases.
FrankD412 Aug 5, 2025
a3f6576
More clean up.
FrankD412 Aug 5, 2025
d2fade2
Update to iterlog options.
FrankD412 Aug 7, 2025
b60f369
Update to iterlog options.
FrankD412 Aug 7, 2025
4a0884f
Create reporting utility for JSON.
FrankD412 Aug 18, 2025
b49c06c
Add streaming.
FrankD412 Aug 18, 2025
f7ee3f2
Pre-commit run.
FrankD412 Aug 19, 2025
1c956a5
Pre-commit run.
FrankD412 Aug 19, 2025
ae99480
Fix cursor tab complete.
FrankD412 Aug 19, 2025
bf4ff5b
Remove func name.
FrankD412 Aug 19, 2025
b0615de
Fix duplicate iteration log instances.
FrankD412 Aug 20, 2025
4936695
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 21, 2025
51e0e2f
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 21, 2025
cac35c8
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 21, 2025
2a7cade
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 21, 2025
55115bb
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 21, 2025
6190869
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 22, 2025
3251ec1
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 22, 2025
b20449d
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 23, 2025
239315a
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 25, 2025
8c70ba7
Merge branch 'main' into fdinatale/trtllm-bench/update_low_trt_opts
FrankD412 Aug 25, 2025
159 changes: 159 additions & 0 deletions tensorrt_llm/bench/benchmark/__init__.py
@@ -0,0 +1,159 @@
import json
from pathlib import Path
from typing import Callable, Dict, Optional

from pydantic import AliasChoices, BaseModel, Field

from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.logger import logger


class GeneralExecSettings(BaseModel):
model_config = {
"extra": "ignore"
} # Ignore extra fields not defined in the model

backend: str = Field(
default="pytorch",
description="The backend to use when running benchmarking")
beam_width: int = Field(default=1, description="Number of search beams")
model_path: Optional[Path] = Field(default=None,
description="Path to model checkpoint")
concurrency: int = Field(
default=-1, description="Desired concurrency rate, <=0 for no limit")
dataset_path: Optional[Path] = Field(default=None,
validation_alias=AliasChoices(
"dataset_path", "dataset"),
description="Path to dataset file")
engine_dir: Optional[Path] = Field(
default=None, description="Path to a serialized TRT-LLM engine")
eos_id: int = Field(
default=-1, description="End-of-sequence token ID, -1 to disable EOS")
iteration_log: Optional[Path] = Field(
default=None, description="Path where iteration logging is written")
kv_cache_percent: float = Field(
default=0.90,
validation_alias=AliasChoices("kv_cache_percent",
"kv_cache_free_gpu_mem_fraction"),
description="Percentage of memory for KV Cache after model load")
max_input_len: int = Field(default=4096,
description="Maximum input sequence length")
max_seq_len: Optional[int] = Field(default=None,
description="Maximum sequence length")
modality: Optional[str] = Field(
default=None, description="Modality of multimodal requests")
model: Optional[str] = Field(default=None, description="Model name or path")
num_requests: int = Field(
default=0, description="Number of requests to cap benchmark run at")
output_json: Optional[Path] = Field(
default=None, description="Path where output should be written")
report_json: Optional[Path] = Field(
default=None, description="Path where report should be written")
request_json: Optional[Path] = Field(
default=None,
description="Path where per request information is written")
streaming: bool = Field(default=False,
description="Whether to use streaming mode")
warmup: int = Field(default=2,
description="Number of requests to warm up benchmark")

@property
def iteration_writer(self) -> IterationWriter:
return IterationWriter(self.iteration_log)

@property
def model_type(self) -> str:
return get_model_config(self.model, self.checkpoint_path).model_type

@property
def checkpoint_path(self) -> Path:
return self.model_path or self.model
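

A quick sketch of how the alias and extra-field handling behaves; the values here are purely illustrative:

# Illustrative values. "dataset" and "kv_cache_free_gpu_mem_fraction"
# resolve through AliasChoices; unknown keys are dropped by extra="ignore".
settings = GeneralExecSettings(
    backend="pytorch",
    dataset="/data/prompts.jsonl",
    kv_cache_free_gpu_mem_fraction=0.85,
    some_unrelated_cli_flag=True,  # hypothetical extra key, silently ignored
)
assert settings.dataset_path == Path("/data/prompts.jsonl")
assert settings.kv_cache_percent == 0.85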


def ignore_trt_only_args(kwargs: dict, backend: str):
"""Ignore TensorRT-only arguments for non-TensorRT backends.

Args:
kwargs: Dictionary of keyword arguments to be passed to the LLM constructor.
backend: The backend type (e.g., "pytorch", "_autodeploy").
"""
trt_only_args = [
"batching_type",
"normalize_log_probs",
"extended_runtime_perf_knob_config",
]
    for arg in trt_only_args:
        # pop() drops falsy values (e.g. False) silently; only explicitly
        # truthy TRT-only settings trigger a warning.
        if kwargs.pop(arg, None):
            logger.warning(f"Ignoring {arg} for {backend} backend.")

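For instance, with the pytorch backend (values made up for illustration):

kwargs = {"max_batch_size": 64, "normalize_log_probs": True}
ignore_trt_only_args(kwargs, backend="pytorch")
# Logs: "Ignoring normalize_log_probs for pytorch backend."
# kwargs is now {"max_batch_size": 64}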

def get_llm(runtime_config: RuntimeConfig, kwargs: dict):
"""Create and return an appropriate LLM instance based on the backend configuration.

Args:
runtime_config: Runtime configuration containing backend selection and settings.
kwargs: Additional keyword arguments to pass to the LLM constructor.

Returns:
An instance of the appropriate LLM class for the specified backend.
"""
llm_cls = LLM

if runtime_config.backend != "tensorrt":
ignore_trt_only_args(kwargs, runtime_config.backend)

if runtime_config.backend == 'pytorch':
llm_cls = PyTorchLLM

if runtime_config.iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True

elif runtime_config.backend == "_autodeploy":
kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
llm_cls = AutoDeployLLM

llm = llm_cls(**kwargs)
return llm
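

A hedged sketch of the dispatch; `runtime_config` is assumed to be an already-built RuntimeConfig, and only its `backend` and `iteration_log` attributes are read here:

# Assumes runtime_config was constructed elsewhere (e.g. from CLI options);
# the kwargs values are illustrative.
kwargs = {"model": "/models/my-checkpoint", "tensor_parallel_size": 2}
llm = get_llm(runtime_config, kwargs)
# backend == "pytorch":     PyTorchLLM, with enable_iter_perf_stats=True
#                           whenever an iteration log is configured.
# backend == "_autodeploy": AutoDeployLLM, with tensor_parallel_size
#                           renamed to world_size.
# backend == "tensorrt":    LLM, with TRT-only kwargs left intact.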


def get_general_cli_options(
params: Dict, bench_env: BenchmarkEnvironment) -> GeneralExecSettings:
"""Get general execution settings from command line parameters.

Args:
params: Dictionary of command line parameters.
bench_env: Benchmark environment containing model and checkpoint information.

Returns:
An instance of GeneralExecSettings containing general execution settings.
"""
# Create a copy of params to avoid modifying the original
settings_dict = params.copy()

# Add derived values that need to be computed from bench_env
model_path = bench_env.checkpoint_path
model = bench_env.model
# Override/add the computed values
settings_dict.update({
"model_path": model_path,
"model": model,
})

# Create and return the settings object, ignoring any extra fields
return GeneralExecSettings(**settings_dict)
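

A sketch of typical use; `bench_env` is assumed to be a populated BenchmarkEnvironment, and only its `model` and `checkpoint_path` attributes are read:

# Param values are illustrative; "dataset" resolves to dataset_path via alias.
params = {"backend": "pytorch", "dataset": "/data/prompts.jsonl", "warmup": 5}
options = get_general_cli_options(params, bench_env)
assert options.model == bench_env.model
assert options.dataset_path == Path("/data/prompts.jsonl")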


def generate_json_report(report_path: Optional[Path], func: Callable):
    """Write the JSON-serialized result of a callable to a report file.

    Args:
        report_path: Path where the report is written; skipped when None.
        func: Zero-argument callable returning a JSON-serializable object.
    """
if report_path is None:
logger.debug("No report path provided, skipping report generation.")
else:
logger.info(f"Writing report information to {report_path}...")
with open(report_path, "w") as f:
f.write(json.dumps(func(), indent=4))
logger.info(f"Report information written to {report_path}.")