Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorrt_llm/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def __init__(self,
is_llm_executor=is_llm_executor)

self.init_rpc_executor()
# Inject the generated HMAC key into worker_kwargs for workers
worker_kwargs['hmac_key'] = self.hmac_key
worker_kwargs['rpc_addr'] = self.rpc_addr
self.create_workers(RayGPUWorker, worker_kwargs)
self.setup_engine_remote()
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/ray_gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
) -> None:
global logger
from tensorrt_llm.logger import logger
Expand All @@ -191,7 +192,7 @@ def __init__(
if rpc_addr is None:
raise RuntimeError(
"RPC mode enabled but no rpc_addr provided to RayGPUWorker")
self.init_rpc_worker(self.global_rank, rpc_addr)
self.init_rpc_worker(self.global_rank, rpc_addr, hmac_key)
self.start_rpc_server()

def setup_engine(self):
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/rpc/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def __init__(self,
self._client_socket = ZeroMqQueue(address=(address, hmac_key),
is_server=False,
is_async=True,
use_hmac_encryption=False,
use_hmac_encryption=hmac_key
is not None,
socket_type=socket_type,
name="rpc_client")
self._pending_futures = {}
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def bind(self, address: str = "tcp://*:5555") -> None:
self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
is_server=True,
is_async=True,
use_hmac_encryption=False,
use_hmac_encryption=self._hmac_key
is not None,
socket_type=socket_type,
name="rpc_server")
logger.info(f"RPCServer is bound to {self._address}")
Expand Down
2 changes: 2 additions & 0 deletions tensorrt_llm/executor/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(

self._create_mpi_session(model_world_size, mpi_session)

# Inject the generated HMAC key into worker_kwargs for workers
worker_kwargs['hmac_key'] = self.hmac_key
self.worker_kwargs = worker_kwargs

self.launch_workers()
Expand Down
4 changes: 3 additions & 1 deletion tensorrt_llm/executor/rpc_proxy_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import atexit
import json
import os
import threading
from typing import Callable, List, Optional

Expand Down Expand Up @@ -29,7 +30,8 @@ class RpcExecutorMixin:

def init_rpc_executor(self):
self.rpc_addr = get_unique_ipc_addr()
self.rpc_client = RPCClient(self.rpc_addr)
self.hmac_key = os.urandom(32)
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)

self._results = {}
self._shutdown_event = threading.Event()
Expand Down
5 changes: 4 additions & 1 deletion tensorrt_llm/executor/rpc_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ def main_task(
color="yellow")
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
# Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async.
rpc_server = RPCServer(worker, num_workers=worker.num_workers)
hmac_key = kwargs.get("hmac_key")
rpc_server = RPCServer(worker,
num_workers=worker.num_workers,
hmac_key=hmac_key)
rpc_server.bind(rpc_addr)
rpc_server.start()
logger_debug(f"[worker] RPC server {mpi_rank()} is started",
Expand Down
5 changes: 3 additions & 2 deletions tensorrt_llm/executor/rpc_worker_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ class RpcWorkerMixin:
# This can be overridden by setting num_workers in the inheriting class
NUM_WORKERS = 6

def init_rpc_worker(self, rank: int, rpc_addr: Optional[str]):
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str], hmac_key: Optional[bytes] = None):
if rpc_addr is None:
raise RuntimeError("RPC mode enabled but no rpc_addr provided to worker")

self.hmac_key = hmac_key
self.rank = rank
self.shutdown_event = Event()
self._response_queue = Queue()
Expand All @@ -41,7 +42,7 @@ def start_rpc_server(self):
if self.rank == 0:
# Use num_workers if set on the instance, otherwise use class default
num_workers = getattr(self, "num_workers", RpcWorkerMixin.NUM_WORKERS)
self.rpc_server = RPCServer(self, num_workers=num_workers)
self.rpc_server = RPCServer(self, num_workers=num_workers, hmac_key=self.hmac_key)
self.rpc_server.bind(self.rpc_addr)
self.rpc_server.start()

Expand Down
37 changes: 37 additions & 0 deletions tests/unittest/executor/test_rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,43 @@ def test_tp2(self, num_reqs):
assert similar(tokenizer.decode(result.outputs[0].token_ids),
'E F G H I J K L')

def test_hmac_key_generation(self):
"""Test that HMAC key is automatically generated and properly propagated."""
tokenizer = TransformersTokenizer.from_pretrained(model_path)
prompt = "A B C D"
prompt_token_ids = tokenizer.encode(prompt)
max_tokens = 8

with self.create_proxy(tp_size=1) as proxy:
assert proxy.hmac_key is not None, "HMAC key should be generated"
assert len(
proxy.hmac_key
) == 32, f"HMAC key should be 32 bytes, got {len(proxy.hmac_key)}"

# Verify key is properly stored in worker_kwargs
assert 'hmac_key' in proxy.worker_kwargs, "HMAC key should be in worker_kwargs"
assert proxy.worker_kwargs[
'hmac_key'] is not None, "HMAC key in worker_kwargs should not be None"

# Verify both references point to the same key object
assert proxy.hmac_key is proxy.worker_kwargs['hmac_key'], \
"HMAC key should be the same object in both locations"

logger_debug(
f"[Test] HMAC key verified: length={len(proxy.hmac_key)} bytes",
color="green")

# Verify RPC communication works with the generated key
sampling_params = SamplingParams(max_tokens=max_tokens)
result = proxy.generate(prompt_token_ids, sampling_params)
assert similar(
tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L'
), "Generation should work with auto-generated HMAC key"

logger_debug(
f"[Test] HMAC key test passed: RPC communication successful",
color="green")


if __name__ == "__main__":
TestRpcProxy().test_tp1(20)
Loading