diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index dc7b79c265c..c8518c83a81 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -896,6 +896,10 @@ def _executor_loop_pp(self):
 
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
+        is_ngram = hasattr(
+            self.model_engine, "spec_config"
+        ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
+        )
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -918,8 +922,7 @@ def _executor_loop(self):
 
                 self._pad_attention_dp_dummy_request()
 
-                if self.draft_model_engine is not None or hasattr(
-                        self, 'drafter') and self.drafter is not None:
+                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -1652,8 +1655,13 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
             if req.is_context_only_request and (req.is_context_finished or
                                                 req.is_finished_due_to_length):
                 self.kv_cache_transceiver.respond_and_send_async(req)
-                self.resource_manager.resource_managers[
-                    ResourceManagerType.SEQ_SLOT_MANAGER].free_resources(req)
+                for resource_mgr_type in (
+                        ResourceManagerType.SEQ_SLOT_MANAGER,
+                        ResourceManagerType.SPEC_RESOURCE_MANAGER):
+                    if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
+                            resource_mgr_type] is not None:
+                        self.resource_manager.resource_managers[
+                            resource_mgr_type].free_resources(req)
 
         self.kv_cache_transceiver.check_context_transfer_status(0)
 
diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
index 540313cfdff..e0ab570ec5c 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
@@ -12,6 +12,7 @@
 from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
 from tensorrt_llm._utils import set_mpi_comm
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MpiCommSession
+from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig
 
 cloudpickle.register_pickle_by_value(sys.modules[__name__])
 MPI.pickle.__init__(
@@ -33,6 +34,11 @@ def model_path(model_name):
     elif 'TinyLlama-1.1B-Chat-v1.0' in model_name:
         return os.path.join(llm_models_root, 'llama-models-v2',
                             'TinyLlama-1.1B-Chat-v1.0')
+    elif 'Llama-3.1-8B-Instruct' in model_name:
+        return os.path.join(llm_models_root, 'llama-3.1-model',
+                            'Llama-3.1-8B-Instruct/')
+    elif 'EAGLE3-LLaMA3.1-Instruct-8B' in model_name:
+        return os.path.join(llm_models_root, 'EAGLE3-LLaMA3.1-Instruct-8B')
     else:
         raise ValueError(f"Unknown model: {model_name}")
 
@@ -313,5 +319,104 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
     print("All workers terminated.")
 
 
+@pytest.mark.parametrize("model", ["Llama-3.1-8B-Instruct"])
+@pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"])
+@pytest.mark.parametrize("generation_overlap", [False])
+def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
+                                                 generation_overlap):
+    # Test whether the batch slots are properly released when using speculative decoding
+    # with disaggregated serving.
+    spec_dec_config = EagleDecodingConfig(
+        speculative_model_dir=model_path(spec_dec_model_path),
+        eagle3_one_model=False,
+        max_draft_len=3)
+
+    worker_pytorch_configs = []
+
+    # Context worker
+    worker_pytorch_configs.append(
+        dict(disable_overlap_scheduler=True,
+             speculative_config=spec_dec_config,
+             max_batch_size=1))
+
+    # Generation worker
+    worker_pytorch_configs.append(
+        dict(disable_overlap_scheduler=not generation_overlap,
+             speculative_config=spec_dec_config,
+             max_batch_size=1))
+
+    kv_cache_configs = [
+        KvCacheConfig(max_tokens=128, enable_block_reuse=False)
+        for _ in range(2)
+    ]
+    model_names = [model_path(model) for _ in range(2)]
+    ranks = [0, 1]
+    worker_args = list(
+        zip(kv_cache_configs, worker_pytorch_configs, model_names, ranks))
+
+    port_name = MPI.Open_port()
+    MPI.Publish_name('my_port', port_name)
+
+    prompt = "What is the capital of Germany?"
+
+    with MPIPoolExecutor(max_workers=2, env={"TRTLLM_USE_MPI_KVCACHE":
+                                             "1"}) as executor:
+        futures = []
+        try:
+            for worker_arg in worker_args:
+                future = executor.submit(worker_entry_point, *worker_arg)
+                futures.append(future)
+        except Exception as e:
+            print(f"Error in worker {worker_arg}: {e}")
+            raise e
+
+        try:
+            print("Launched all the workers.")
+            intercomm = MPI.COMM_SELF.Accept(port_name)
+
+            for _ in range(2):
+                intercomm.recv(tag=MPI_READY)
+                print("Received ready signal.")
+            max_tokens = 25
+
+            requests = []
+            for _ in range(10):
+                requests.append(
+                    (prompt, SamplingParams(max_tokens=1, ignore_eos=True),
+                     DisaggregatedParams(request_type="context_only")))
+
+            intercomm.send(requests, dest=0, tag=MPI_REQUEST)
+
+            for _ in range(len(requests)):
+                output = intercomm.recv(source=0, tag=MPI_RESULT)
+                assert output[0].disaggregated_params is not None
+                assert output[
+                    0].disaggregated_params.request_type == "context_only"
+                assert len(output[0].token_ids) == 1
+
+                generation_request_disagg_params = output[
+                    0].disaggregated_params
+                generation_request_disagg_params.request_type = "generation_only"
+                requests = []
+                requests.append((prompt,
+                                 SamplingParams(max_tokens=max_tokens,
+                                                ignore_eos=True),
+                                 generation_request_disagg_params))
+
+                intercomm.send(requests, dest=1, tag=MPI_REQUEST)
+                output = intercomm.recv(source=1, tag=MPI_RESULT)
+
+        finally:
+            # Send termination requests
+            intercomm.send(None, dest=0, tag=MPI_REQUEST)
+            intercomm.send(None, dest=1, tag=MPI_REQUEST)
+            print("Sent termination requests to the workers.")
+
+            # Wait for all futures to complete
+            for future in futures:
+                future.result()
+            print("All workers terminated.")
+
+
 if __name__ == "__main__":
     pytest.main()
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index ca678f13ef5..66ce79bb239 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -68,6 +68,7 @@ l0_h100:
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8]
+  - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_spec_dec_batch_slot_limit[False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct]
   - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
   - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
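
Note (not part of the patch): below is a minimal, self-contained sketch of the cleanup pattern the py_executor.py hunk introduces, i.e. freeing a finished context-only request from every registered resource manager rather than only the sequence-slot manager. The enum and manager classes are hypothetical stand-ins for illustration, not TensorRT-LLM's real ResourceManagerType/ResourceManager API.

# Illustration only: mirrors the loop added in _send_disagg_ctx_cache, using
# stand-in types instead of TensorRT-LLM's ResourceManagerType/ResourceManager.
from enum import Enum, auto


class ResourceManagerType(Enum):
    SEQ_SLOT_MANAGER = auto()
    SPEC_RESOURCE_MANAGER = auto()


class DummyManager:
    """Hypothetical manager that just records which requests were freed."""

    def __init__(self):
        self.freed = []

    def free_resources(self, req):
        self.freed.append(req)


def free_context_request(resource_managers, req):
    # Free the request's slot in every manager that is registered and non-None;
    # with speculative decoding, the spec resource manager holds a slot too.
    for mgr_type in (ResourceManagerType.SEQ_SLOT_MANAGER,
                     ResourceManagerType.SPEC_RESOURCE_MANAGER):
        mgr = resource_managers.get(mgr_type)
        if mgr is not None:
            mgr.free_resources(req)


if __name__ == "__main__":
    managers = {
        ResourceManagerType.SEQ_SLOT_MANAGER: DummyManager(),
        ResourceManagerType.SPEC_RESOURCE_MANAGER: DummyManager(),
    }
    free_context_request(managers, req="context-request-0")
    assert all(m.freed == ["context-request-0"] for m in managers.values())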