16 changes: 12 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -896,6 +896,10 @@ def _executor_loop_pp(self):

def _executor_loop(self):
torch.cuda.set_device(self.device_id)
is_ngram = hasattr(
self.model_engine, "spec_config"
) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
)
with self._profiler() as profile_step:
sample_state = None
iter_start_time = time.time()
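
For context, a minimal standalone sketch of the gating the hunk above introduces: the executor loop computes an `is_ngram` flag once, guarding against a missing or empty `spec_config` before asking the decoding mode whether it is ngram-based. The `SpecDecMode`, `SpecConfig`, and `Engine` classes below are hypothetical stand-ins, not the TensorRT-LLM types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecDecMode:
    name: str

    def is_ngram(self) -> bool:
        return self.name == "ngram"


@dataclass
class SpecConfig:
    spec_dec_mode: SpecDecMode


@dataclass
class Engine:
    spec_config: Optional[SpecConfig] = None


engine = Engine(spec_config=SpecConfig(SpecDecMode("ngram")))

# Same defensive pattern as the hunk above: only call is_ngram() once we
# know a speculative config is actually attached to the engine.
is_ngram = (hasattr(engine, "spec_config") and engine.spec_config is not None
            and engine.spec_config.spec_dec_mode.is_ngram())
print(is_ngram)  # True for the ngram config above
```
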
@@ -918,8 +922,7 @@ def _executor_loop(self):

self._pad_attention_dp_dummy_request()

if self.draft_model_engine is not None or hasattr(
self, 'drafter') and self.drafter is not None:
if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
self._prepare_draft_requests(self.active_requests)

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -1652,8 +1655,13 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
if req.is_context_only_request and (req.is_context_finished or
req.is_finished_due_to_length):
self.kv_cache_transceiver.respond_and_send_async(req)
self.resource_manager.resource_managers[
ResourceManagerType.SEQ_SLOT_MANAGER].free_resources(req)
for resource_mgr_type in (
ResourceManagerType.SEQ_SLOT_MANAGER,
ResourceManagerType.SPEC_RESOURCE_MANAGER):
if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
resource_mgr_type] is not None:
self.resource_manager.resource_managers[
resource_mgr_type].free_resources(req)

self.kv_cache_transceiver.check_context_transfer_status(0)

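To make the new release path above concrete, here is a minimal, self-contained sketch of the same pattern: once a context-only request has finished, free its sequence slot and, if present, the speculative-decoding resources, skipping managers that are missing or None (here via `dict.get`, which has the same effect as the membership-plus-None check in the diff). `ResourceManagerType`, `DummySlotManager`, and the literal request id are hypothetical stand-ins rather than the TensorRT-LLM classes.

```python
from enum import Enum


class ResourceManagerType(Enum):
    SEQ_SLOT_MANAGER = "seq_slot_manager"
    SPEC_RESOURCE_MANAGER = "spec_resource_manager"


class DummySlotManager:

    def free_resources(self, request_id: int) -> None:
        print(f"freed resources for request {request_id}")


# SPEC_RESOURCE_MANAGER may be absent or None when speculative decoding
# is not enabled; the loop below must tolerate both cases.
resource_managers = {
    ResourceManagerType.SEQ_SLOT_MANAGER: DummySlotManager(),
    ResourceManagerType.SPEC_RESOURCE_MANAGER: None,
}

request_id = 42
for mgr_type in (ResourceManagerType.SEQ_SLOT_MANAGER,
                 ResourceManagerType.SPEC_RESOURCE_MANAGER):
    manager = resource_managers.get(mgr_type)
    if manager is not None:
        manager.free_resources(request_id)
```
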
105 changes: 105 additions & 0 deletions tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
@@ -12,6 +12,7 @@
from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MpiCommSession
from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@@ -33,6 +34,11 @@ def model_path(model_name):
elif 'TinyLlama-1.1B-Chat-v1.0' in model_name:
return os.path.join(llm_models_root, 'llama-models-v2',
'TinyLlama-1.1B-Chat-v1.0')
elif 'Llama-3.1-8B-Instruct' in model_name:
return os.path.join(llm_models_root, 'llama-3.1-model',
'Llama-3.1-8B-Instruct/')
elif 'EAGLE3-LLaMA3.1-Instruct-8B' in model_name:
return os.path.join(llm_models_root, 'EAGLE3-LLaMA3.1-Instruct-8B')
else:
raise ValueError(f"Unknown model: {model_name}")

@@ -313,5 +319,104 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
print("All workers terminated.")


@pytest.mark.parametrize("model", ["Llama-3.1-8B-Instruct"])
@pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"])
@pytest.mark.parametrize("generation_overlap", [False])
def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
generation_overlap):
    # Test that batch slots are properly released when speculative decoding
    # is used with disaggregated serving.
spec_dec_config = EagleDecodingConfig(
speculative_model_dir=model_path(spec_dec_model_path),
eagle3_one_model=False,
max_draft_len=3)

worker_pytorch_configs = []

# Context worker
worker_pytorch_configs.append(
dict(disable_overlap_scheduler=True,
speculative_config=spec_dec_config,
max_batch_size=1))

# Generation worker
worker_pytorch_configs.append(
dict(disable_overlap_scheduler=not generation_overlap,
speculative_config=spec_dec_config,
max_batch_size=1))

kv_cache_configs = [
KvCacheConfig(max_tokens=128, enable_block_reuse=False)
for _ in range(2)
]
model_names = [model_path(model) for _ in range(2)]
ranks = [0, 1]
worker_args = list(
zip(kv_cache_configs, worker_pytorch_configs, model_names, ranks))

port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)

prompt = "What is the capital of Germany?"

with MPIPoolExecutor(max_workers=2, env={"TRTLLM_USE_MPI_KVCACHE":
"1"}) as executor:
futures = []
try:
for worker_arg in worker_args:
future = executor.submit(worker_entry_point, *worker_arg)
futures.append(future)
except Exception as e:
print(f"Error in worker {worker_arg}: {e}")
raise e

try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)

for _ in range(2):
intercomm.recv(tag=MPI_READY)
print("Received ready signal.")
max_tokens = 25

requests = []
for _ in range(10):
requests.append(
(prompt, SamplingParams(max_tokens=1, ignore_eos=True),
DisaggregatedParams(request_type="context_only")))

intercomm.send(requests, dest=0, tag=MPI_REQUEST)

for _ in range(len(requests)):
output = intercomm.recv(source=0, tag=MPI_RESULT)
assert output[0].disaggregated_params is not None
assert output[
0].disaggregated_params.request_type == "context_only"
assert len(output[0].token_ids) == 1

generation_request_disagg_params = output[
0].disaggregated_params
generation_request_disagg_params.request_type = "generation_only"
requests = []
requests.append((prompt,
SamplingParams(max_tokens=max_tokens,
ignore_eos=True),
generation_request_disagg_params))

intercomm.send(requests, dest=1, tag=MPI_REQUEST)
output = intercomm.recv(source=1, tag=MPI_RESULT)

finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")

# Wait for all futures to complete
for future in futures:
future.result()
print("All workers terminated.")


if __name__ == "__main__":
pytest.main()
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -68,6 +68,7 @@ l0_h100:
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8]
- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_spec_dec_batch_slot_limit[False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct]
- test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
- test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]