31 changes: 16 additions & 15 deletions tensorrt_llm/bench/benchmark/low_latency.py
@@ -21,11 +21,11 @@
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
from tensorrt_llm.llmapi import CapacitySchedulerPolicy
from tensorrt_llm.llmapi import BackendType, CapacitySchedulerPolicy
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

# isort: off
from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS
from tensorrt_llm.bench.benchmark.utils.general import get_settings_from_engine, get_settings
# isort: on
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer,
@@ -46,10 +46,11 @@
default=None,
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option("--backend",
type=click.Choice(ALL_SUPPORTED_BACKENDS),
default="pytorch",
help="The backend to use when running benchmarking.")
@optgroup.option(
"--backend",
type=click.Choice(BackendType.canonical_values()),
default=None,
help="The backend to use when running benchmarking. Default is 'pytorch'.")
@optgroup.option(
"--kv_cache_free_gpu_mem_fraction",
type=float,
@@ -179,6 +180,11 @@ def latency_command(
"""Run a latency test on a TRT-LLM engine."""

logger.info("Preparing to run latency benchmark...")

params["backend"] = BackendType.get_default_backend_with_warning(
params.get("backend"))
BackendType.print_backend_info(params.get("backend"))

# Parameters from CLI
# Model, experiment, and engine params
dataset_path: Path = params.get("dataset")
@@ -192,7 +198,7 @@
modality: str = params.get("modality")
max_input_len: int = params.get("max_input_len")
max_seq_len: int = params.get("max_seq_len")
backend: str = params.get("backend")
backend: BackendType = BackendType(params.get("backend"))
model_type = get_model_config(model, checkpoint_path).model_type

# Runtime Options
@@ -228,8 +234,7 @@

# Engine configuration parsing for PyTorch backend
kwargs = {}
if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
) != "tensorrt":
if backend != BackendType.TENSORRT:
if bench_env.checkpoint_path is None:
snapshot_download(model)

@@ -238,7 +243,7 @@
kwargs_max_sql = max_seq_len or metadata.max_sequence_length
logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
kwargs["max_seq_len"] = kwargs_max_sql
elif backend.lower() == "tensorrt":
else: # TensorRT backend
assert max_seq_len is None, (
"max_seq_len is not a runtime parameter for C++ backend")
exec_settings, build_cfg = get_settings_from_engine(engine_dir)
@@ -250,10 +255,6 @@
"dataset contains a maximum sequence of "
f"{metadata.max_sequence_length}. Please rebuild a new engine to"
"support this dataset.")
else:
raise RuntimeError(
f"Invalid backend: {backend}, please use one of the following: "
f"{ALL_SUPPORTED_BACKENDS}")

exec_settings["model"] = model
engine_tokens = exec_settings["settings_config"]["max_num_tokens"]
@@ -290,7 +291,7 @@

llm = None
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend
kwargs['backend'] = backend.canonical_value

try:
logger.info("Setting up latency benchmark.")
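For context on how the two bench commands now share one resolution path, here is a minimal sketch (not part of the diff) built on the BackendType helpers this PR adds in tensorrt_llm/llmapi/utils.py; cli_value is a hypothetical stand-in for the click option value:

from tensorrt_llm.llmapi import BackendType

# Illustrative only: click passes None when --backend is omitted.
cli_value = None

# Warn and fall back to the new default ('pytorch') when unset;
# otherwise validate the string and return the matching enum member.
backend = BackendType.get_default_backend_with_warning(cli_value)
assert backend == BackendType.PYTORCH

# Command logic compares enum members directly ...
if backend != BackendType.TENSORRT:
    pass  # PyTorch / AutoDeploy path

# ... while the LLM kwargs receive the canonical string form.
llm_kwargs = {"backend": backend.canonical_value}  # -> "pytorch"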
34 changes: 21 additions & 13 deletions tensorrt_llm/bench/benchmark/throughput.py
@@ -16,7 +16,7 @@

# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS)
get_settings_from_engine, get_settings)
# isort: on
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm._tensorrt_engine import LLM
@@ -28,7 +28,7 @@
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer,
update_metadata_for_multimodal)
from tensorrt_llm.llmapi import CapacitySchedulerPolicy
from tensorrt_llm.llmapi import BackendType, CapacitySchedulerPolicy
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import SamplingParams

@@ -45,10 +45,11 @@
default=None,
help="Path to a serialized TRT-LLM engine.",
)
@optgroup.option("--backend",
type=click.Choice(ALL_SUPPORTED_BACKENDS),
default="pytorch",
help="The backend to use when running benchmarking.")
@optgroup.option(
"--backend",
type=click.Choice(BackendType.canonical_values()),
default=None,
help="The backend to use when running benchmarking. Default is 'pytorch'.")
@optgroup.option(
"--extra_llm_api_options",
type=str,
@@ -254,6 +255,10 @@ def throughput_command(
"""Run a throughput test on a TRT-LLM engine."""

logger.info("Preparing to run throughput benchmark...")
params["backend"] = BackendType.get_default_backend_with_warning(
params.get("backend"))
BackendType.print_backend_info(params.get("backend"))

# Parameters from CLI
# Model, experiment, and engine params
dataset_path: Path = params.get("dataset")
@@ -265,7 +270,7 @@
checkpoint_path: Path = bench_env.checkpoint_path or bench_env.model
engine_dir: Path = params.get("engine_dir")
concurrency: int = params.get("concurrency")
backend: str = params.get("backend")
backend: BackendType = params.get("backend")
modality: str = params.get("modality")
max_input_len: int = params.get("max_input_len")
model_type = get_model_config(model, checkpoint_path).model_type
@@ -306,8 +311,7 @@
logger.info(metadata.get_summary_for_print())

# Engine configuration parsing
if backend and backend.lower() in ALL_SUPPORTED_BACKENDS and backend.lower(
) != "tensorrt":
if backend != BackendType.TENSORRT:
# If we're dealing with a model name, perform a snapshot download to
# make sure we have a local copy of the model.
if bench_env.checkpoint_path is None:
@@ -318,7 +322,7 @@
kwargs_max_sql = max_seq_len or metadata.max_sequence_length
logger.info(f"Setting PyTorch max sequence length to {kwargs_max_sql}")
kwargs["max_seq_len"] = kwargs_max_sql
elif backend.lower() == "tensorrt":
elif backend == BackendType.TENSORRT:
assert max_seq_len is None, (
"max_seq_len is not a runtime parameter for C++ backend")
exec_settings, build_cfg = get_settings_from_engine(engine_dir)
@@ -334,7 +338,7 @@
else:
raise RuntimeError(
f"Invalid backend: {backend}, please use one of the following: "
"pytorch, tensorrt, _autodeploy.")
f"{BackendType.canonical_values()}")

exec_settings["model"] = model
engine_bs = exec_settings["settings_config"]["max_batch_size"]
@@ -367,6 +371,10 @@
exec_settings["extra_llm_api_options"] = params.pop("extra_llm_api_options")
exec_settings["iteration_log"] = iteration_log

if exec_backend := exec_settings.get("backend", None):
if isinstance(exec_backend, BackendType):
exec_settings["backend"] = exec_backend.canonical_value

# Construct the runtime configuration dataclass.
runtime_config = RuntimeConfig(**exec_settings)
llm = None
@@ -385,9 +393,9 @@ def ignore_trt_only_args(kwargs: dict):
try:
logger.info("Setting up throughput benchmark.")
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend
kwargs['backend'] = backend.canonical_value

if backend == "pytorch" and iteration_log is not None:
if backend == BackendType.PYTORCH and iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True

if runtime_config.backend == 'pytorch':
2 changes: 0 additions & 2 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -22,8 +22,6 @@
QuantAlgo.NVFP4.value: "fp8",
}

ALL_SUPPORTED_BACKENDS = ["pytorch", "_autodeploy", "tensorrt"]


def get_settings_from_engine(
engine_path: Path
48 changes: 30 additions & 18 deletions tensorrt_llm/commands/serve.py
@@ -25,6 +25,7 @@
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
from tensorrt_llm.llmapi.mpi_session import find_free_port
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
from tensorrt_llm.llmapi.utils import BackendType
from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer

@@ -71,7 +72,7 @@ def _signal_handler_cleanup_child(signum, frame):

def get_llm_args(model: str,
tokenizer: Optional[str] = None,
backend: str = "pytorch",
backend: BackendType = BackendType.PYTORCH,
max_beam_width: int = BuildConfig.max_beam_width,
max_batch_size: int = BuildConfig.max_batch_size,
max_num_tokens: int = BuildConfig.max_num_tokens,
@@ -137,7 +138,7 @@ def get_llm_args(model: str,
"kv_cache_config":
kv_cache_config,
"backend":
backend if backend == "pytorch" else None,
backend if backend == BackendType.PYTORCH else None,
"num_postprocess_workers":
num_postprocess_workers,
"postprocess_tokenizer_dir":
@@ -157,10 +158,15 @@ def launch_server(host: str,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None):

backend = llm_args["backend"]
if backend := llm_args.get("backend", None):
if isinstance(backend, BackendType):
llm_args["backend"] = backend.canonical_value
else:
backend = BackendType.get_default_backend_with_warning(backend)

model = llm_args["model"]

if backend == 'pytorch':
if backend == BackendType.PYTORCH:
llm = PyTorchLLM(**llm_args)
else:
llm = LLM(**llm_args)
@@ -185,10 +191,12 @@ def launch_server(host: str,
default="localhost",
help="Hostname of the server.")
@click.option("--port", type=int, default=8000, help="Port of the server.")
@click.option("--backend",
type=click.Choice(["pytorch", "trt"]),
default="pytorch",
help="Set to 'pytorch' for pytorch path. Default is cpp path.")
@click.option(
"--backend",
type=click.Choice(BackendType.canonical_values()),
default=None,
help=
"Set to 'tensorrt' for TensorRT engine path. Default is the pytorch path.")
@click.option('--log_level',
type=click.Choice(severity_map.keys()),
default='info',
@@ -277,22 +285,26 @@ def launch_server(host: str,
help=
"Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
)
def serve(
model: str, tokenizer: Optional[str], host: str, port: int,
log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
ep_size: Optional[int], cluster_size: Optional[int],
gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool):
def serve(model: str, tokenizer: Optional[str], host: str, port: int,
log_level: str, backend: Optional[str], max_beam_width: int,
max_batch_size: int, max_num_tokens: int, max_seq_len: int,
tp_size: int, pp_size: int, ep_size: Optional[int],
cluster_size: Optional[int], gpus_per_node: Optional[int],
kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str],
server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool):
"""Running an OpenAI API compatible server

MODEL: model name | HF checkpoint path | TensorRT engine path
"""
logger.set_level(log_level)

backend: BackendType = BackendType.get_default_backend_with_warning(backend)
BackendType.print_backend_info(backend)

llm_args, _ = get_llm_args(
model=model,
tokenizer=tokenizer,
2 changes: 2 additions & 0 deletions tensorrt_llm/llmapi/__init__.py
@@ -17,6 +17,7 @@
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
QuantConfig)
from .mpi_session import MpiCommSession
from .utils import BackendType

__all__ = [
'LLM',
@@ -56,4 +57,5 @@
'TrtLlmArgs',
'AutoDecodingConfig',
'AttentionDpConfig',
'BackendType',
]
2 changes: 1 addition & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -124,7 +124,7 @@ def __init__(self,

try:
backend = kwargs.get('backend', None)
if backend == 'pytorch':
if backend in ('pytorch', None):
llm_args_cls = TorchLlmArgs
elif backend == '_autodeploy':
from .._torch.auto_deploy.llm_args import \
52 changes: 52 additions & 0 deletions tensorrt_llm/llmapi/utils.py
@@ -20,6 +20,7 @@
import huggingface_hub
import psutil
import torch
from aenum import MultiValueEnum
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from tqdm.auto import tqdm
@@ -634,3 +635,54 @@ def amend_api_doc_with_status_tags(cls, method: Callable) -> str:

set_api_status = ApiStatusRegistry().set_api_status
get_api_status = ApiStatusRegistry().get_api_status


class BackendType(MultiValueEnum):
""" The backend type. """
# display_name, canonical_value, other aliases ...
PYTORCH = "PyTorch", "pytorch", "PyT"
TENSORRT = "TensorRT", "tensorrt", "trt"
_AUTODEPLOY = "AutoDeploy", "_autodeploy"

def __str__(self):
return self.values[0]

@property
def canonical_value(self):
return self.values[1]

@staticmethod
def default_value() -> str:
""" Default value for the backend. """
return BackendType.PYTORCH.canonical_value

@staticmethod
def canonical_values() -> list[str]:
""" Canonical values for the backend. """
return [v.canonical_value for v in BackendType]

# Several utils for unified behavior across trtllm-serve and trtllm-bench
@staticmethod
def print_backend_info(backend: "BackendType"):
""" Print the backend info. """
logger.info(f"Running with {backend} backend.")

# TODO[Superjom]: Remove this method after v1.0.0 is released.
@staticmethod
def get_default_backend_with_warning(
backend: Optional[str]) -> "BackendType":
""" Warn the user if the backend is not set, as we changed the default
backend to from tensorrt topytorch from v1.0 """
if backend is None:
logger.warning(
f"The default backend becomes 'pytorch' from v1.0, for TensorRT "
"engine, please use `--backend tensorrt` instead.")
backend = BackendType.default_value()

# Check that the backend is one of the canonical values.
if backend not in BackendType.canonical_values():
raise ValueError(
f"Invalid backend: {backend}. Please use one of the following: "
f"{BackendType.canonical_values()}")

return BackendType(backend)
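
A short usage sketch of the new enum (not part of the diff), assuming aenum's MultiValueEnum semantics where a member can be looked up by any of its listed values:

from tensorrt_llm.llmapi.utils import BackendType

# Lookup resolves from the canonical value or any alias.
assert BackendType("pytorch") is BackendType.PYTORCH
assert BackendType("trt") is BackendType.TENSORRT
# Display name, canonical value, and the full canonical list.
assert str(BackendType.TENSORRT) == "TensorRT"
assert BackendType.TENSORRT.canonical_value == "tensorrt"
assert BackendType.canonical_values() == ["pytorch", "tensorrt", "_autodeploy"]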