diff --git a/Cargo.lock b/Cargo.lock index 6dd34b06ca..67bef6ce8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1619,6 +1619,7 @@ dependencies = [ "cudarc 0.16.2", "derive-getters", "derive_builder", + "dialoguer", "dynamo-runtime", "either", "erased-serde", @@ -1627,6 +1628,7 @@ dependencies = [ "galil-seiferas", "ggus", "hf-hub", + "humantime", "insta", "itertools 0.14.0", "lazy_static", @@ -1677,14 +1679,12 @@ dependencies = [ "async-stream", "async-trait", "clap", - "dialoguer", "dynamo-engine-llamacpp", "dynamo-engine-mistralrs", "dynamo-llm", "dynamo-runtime", "futures", "futures-util", - "humantime", "libc", "regex", "serde", diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 3546a9bb30..1e94354746 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -47,7 +47,7 @@ struct Args { /// Block size for the router #[arg(long)] - block_size: usize, + block_size: u32, } fn main() -> Result<()> { @@ -88,7 +88,7 @@ impl WorkerSelector for CustomWorkerSelector { &self, workers: &ProcessedEndpoints, request: &SchedulingRequest, - block_size: usize, + block_size: u32, ) -> Result { // customize logic here // F12 into [DefaultWorkerSelector] to see the original logic diff --git a/launch/dynamo-run/Cargo.toml b/launch/dynamo-run/Cargo.toml index 0cceb324d8..ae3750b652 100644 --- a/launch/dynamo-run/Cargo.toml +++ b/launch/dynamo-run/Cargo.toml @@ -34,7 +34,6 @@ anyhow = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } futures = { workspace = true } -humantime = { workspace = true } libc = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } @@ -47,7 +46,6 @@ uuid = { workspace = true } async-openai = { workspace = true } clap = { version = "4.5", features = ["derive", "env"] } -dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } futures-util = { version = "0.3" } regex = "1" diff --git a/launch/dynamo-run/src/flags.rs b/launch/dynamo-run/src/flags.rs index 2ac7286302..bc9986d596 100644 --- a/launch/dynamo-run/src/flags.rs +++ b/launch/dynamo-run/src/flags.rs @@ -17,9 +17,13 @@ use std::collections::HashMap; use std::path::PathBuf; use clap::ValueEnum; +use dynamo_llm::entrypoint::RouterConfig; use dynamo_llm::kv_router::KvRouterConfig; +use dynamo_llm::local_model::LocalModel; use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode; +use crate::Output; + /// Required options depend on the in and out choices #[derive(clap::Parser, Debug, Clone)] #[command(version, about, long_about = None)] @@ -125,11 +129,11 @@ pub struct Flags { /// context length (e.g. Llama 4). /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. #[arg(long)] - pub context_length: Option<usize>, + pub context_length: Option<u32>, /// KV cache block size (vllm only) #[arg(long)] - pub kv_cache_block_size: Option<usize>, + pub kv_cache_block_size: Option<u32>, /// Additional engine-specific arguments from a JSON file. /// Contains a mapping of parameter names to values. @@ -154,66 +158,63 @@ pub struct Flags { } impl Flags { - /// Get KV router configuration - pub fn kv_router_config(&self) -> KvRouterConfig { - KvRouterConfig::new( - self.kv_overlap_score_weight, - self.kv_gpu_cache_usage_weight, - self.kv_waiting_requests_weight, - ) + /// For each Output variant, check if it would be able to run. + /// This takes validation out of the main engine creation path. 
+ pub fn validate(&self, local_model: &LocalModel, out_opt: &Output) -> anyhow::Result<()> { + match out_opt { + Output::Dynamic => { + if self.context_length.is_some() { + anyhow::bail!("'--context-length' flag should only be used on the worker node, not on the ingress"); + } + if self.kv_cache_block_size.is_some() { + anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress"); + } + } + Output::EchoFull => {} + Output::EchoCore => { + if !local_model.card().has_tokenizer() { + anyhow::bail!( + "out=echo_core needs to find the tokenizer. Pass flag --model-path " + ); + }; + } + #[cfg(feature = "mistralrs")] + Output::MistralRs => {} + Output::SgLang => { + if !local_model.path().is_dir() { + // TODO GGUF support for sglang: https://github.com/ai-dynamo/dynamo/issues/572 + anyhow::bail!("`--model-path` should point at a HuggingFace repo checkout"); + } + } + Output::Vllm => { + if self.base_gpu_id != 0 { + anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead."); + } + } + Output::Trtllm => { + if self.base_gpu_id != 0 { + anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead."); + } + } + #[cfg(feature = "llamacpp")] + Output::LlamaCpp => { + if !local_model.path().is_file() { + anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors."); + } + } + } + Ok(()) } - /// Convert the flags back to a command line. Including only the non-null values, but - /// include the defaults. Includes the canonicalized model path and normalized model name. - /// - /// Used to pass arguments to python engines via `pystr` and `pytok`. - pub fn as_vec(&self, path: &str, name: &str) -> Vec<String> { - let mut out = vec![ - "--model-path".to_string(), - path.to_string(), - "--model-name".to_string(), - name.to_string(), - "--http-port".to_string(), - self.http_port.to_string(), - // Default 1 - "--tensor-parallel-size".to_string(), - self.tensor_parallel_size.to_string(), - // Default 0 - "--base-gpu-id".to_string(), - self.base_gpu_id.to_string(), - // Default 1 - "--num-nodes".to_string(), - self.num_nodes.to_string(), - // Default 0 - "--node-rank".to_string(), - self.node_rank.to_string(), - ]; - if let Some(model_config_path) = self.model_config.as_ref() { - out.push("--model-config".to_string()); - out.push(model_config_path.display().to_string()); - } - if let Some(leader) = self.leader_addr.as_ref() { - out.push("--leader-addr".to_string()); - out.push(leader.to_string()); - } - if let Some(extra_engine_args) = self.extra_engine_args.as_ref() { - out.push("--extra-engine-args".to_string()); - out.push(extra_engine_args.display().to_string()); - } - if let Some(weight) = self.kv_overlap_score_weight { - out.push("--kv-overlap-score-weight".to_string()); - out.push(weight.to_string()); - } - if let Some(weight) = self.kv_gpu_cache_usage_weight { - out.push("--kv-gpu-cache-usage-weight".to_string()); - out.push(weight.to_string()); - } - if let Some(weight) = self.kv_waiting_requests_weight { - out.push("--kv-waiting-requests-weight".to_string()); - out.push(weight.to_string()); - } - out.extend(self.last.clone()); - out + pub fn router_config(&self) -> RouterConfig { + RouterConfig::new( + self.router_mode.into(), + KvRouterConfig::new( + self.kv_overlap_score_weight, + self.kv_gpu_cache_usage_weight, + self.kv_waiting_requests_weight, + ), + ) } /// Load extra engine arguments from a JSON file diff --git 
a/launch/dynamo-run/src/input.rs b/launch/dynamo-run/src/input.rs deleted file mode 100644 index 25d294b279..0000000000 --- a/launch/dynamo-run/src/input.rs +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -pub mod batch; -mod common; -pub mod endpoint; -pub mod http; -pub mod text; diff --git a/launch/dynamo-run/src/lib.rs b/launch/dynamo-run/src/lib.rs index 82f0841a28..f9cf8bc7dd 100644 --- a/launch/dynamo-run/src/lib.rs +++ b/launch/dynamo-run/src/lib.rs @@ -1,328 +1,184 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::time::Duration; use std::{future::Future, pin::Pin}; -use std::{io::Read, sync::Arc, time::Duration}; -use anyhow::Context; -use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, local_model::LocalModel}; -use dynamo_runtime::protocols::Endpoint as EndpointId; -use dynamo_runtime::slug::Slug; -use dynamo_runtime::{CancellationToken, DistributedRuntime}; +use anyhow::Context as _; +use dynamo_llm::entrypoint::input::Input; +use dynamo_llm::entrypoint::EngineConfig; +use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; +use dynamo_runtime::CancellationToken; mod flags; pub use flags::Flags; -mod input; mod opt; pub use dynamo_llm::request_template::RequestTemplate; -pub use opt::{Input, Output}; +pub use opt::Output; mod subprocess; const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2); -/// Default size of a KV cache block. Override with --kv-cache-block-size -const DEFAULT_KV_CACHE_BLOCK_SIZE: usize = 16; - -pub enum EngineConfig { - /// Remote networked engines - Dynamic, - - /// A Full service engine does it's own tokenization and prompt formatting. - StaticFull { - engine: Arc<dyn StreamingEngine>, - model: Box<LocalModel>, - }, - - /// A core engine expects to be wrapped with pre/post processors that handle tokenization. - StaticCore { - engine: ExecutionContext, - model: Box<LocalModel>, - }, -} - -fn is_in_dynamic(in_opt: &Input) -> bool { - matches!(in_opt, Input::Endpoint(_)) -} - -fn is_out_dynamic(out_opt: &Option<Output>) -> bool { - matches!(out_opt, Some(Output::Dynamic)) -} - pub async fn run( runtime: dynamo_runtime::Runtime, in_opt: Input, out_opt: Option<Output>, flags: Flags, ) -> anyhow::Result<()> { - if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) { - anyhow::bail!("Cannot use endpoint for both in and out"); - } - - let cancel_token = runtime.primary_token(); - let maybe_path = flags - .model_path_pos - .clone() - .or(flags.model_path_flag.clone()); - - let mut local_model: LocalModel = if is_out_dynamic(&out_opt) { - // If output is dynamic we are ingress and don't have a local model, but making an - // empty one cleans up the code. 
- Default::default() - } else { - // All other output types have a local model - match &maybe_path { - Some(model_path) => { - LocalModel::prepare( - model_path.to_str().context("Invalid UTF-8 in model path")?, - flags.model_config.as_deref(), - flags.model_name.clone(), - ) - .await? - } - None => { - // echo_full engine doesn't need a path - match &flags.model_name { - Some(name) => LocalModel::with_name_only(name), - None => Default::default(), - } - } - } + // + // Configure + // + + let mut builder = LocalModelBuilder::default(); + builder + .model_path( + flags + .model_path_pos + .clone() + .or(flags.model_path_flag.clone()), + ) + .model_name(flags.model_name.clone()) + .kv_cache_block_size(flags.kv_cache_block_size) + // Only set if user provides. Usually loaded from tokenizer_config.json + .context_length(flags.context_length) + .http_port(flags.http_port) + .router_config(flags.router_config()) + .request_template(flags.request_template.clone()); + + // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. + // If not, then the endpoint isn't exposed so we let LocalModel invent one. + if let Input::Endpoint(path) = &in_opt { + builder.endpoint_id(path.parse().with_context(|| path.clone())?); }; - // Only set if user provides. Usually loaded from tokenizer_config.json - if let Some(context_length) = flags.context_length { - local_model.set_context_length(context_length); - } - // Always set, there is no engine provided default - local_model.set_kv_cache_block_size( - flags - .kv_cache_block_size - .unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE), - ); - - let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process - - let template = if let Some(path) = flags.request_template.as_ref() { - let template = RequestTemplate::load(path)?; - tracing::debug!("Using request template: {template:?}"); - Some(template) - } else { - None - }; + let local_model = builder.build().await?; - // We may need it later - let card = local_model.card().clone(); + // + // Create an engine + // - let out_opt = out_opt.unwrap_or_else(|| { - let default_engine = if card.is_gguf() { - gguf_default() - } else { - safetensors_default() - }; - tracing::info!( - "Using default engine: {default_engine}. Use out= to specify one of {}", - Output::available_engines().join(", ") - ); - default_engine - }); + let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model)); print_cuda(&out_opt); - // Create the engine matching `out` - let engine_config = match out_opt { - Output::Dynamic => { - // Sanity check - TODO probably make a general sanity check at start of method - if flags.context_length.is_some() { - anyhow::bail!("'--content-length' flag should only be used on the worker node, not on the ingress"); - } - if flags.kv_cache_block_size.is_some() { - anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress"); - } - EngineConfig::Dynamic - } - Output::EchoFull => EngineConfig::StaticFull { - model: Box::new(local_model), - engine: dynamo_llm::engines::make_engine_full(), - }, - Output::EchoCore => { - let card = local_model.card(); - if !card.has_tokenizer() { - anyhow::bail!( - "out=echo_core need to find the tokenizer. 
Pass flag --model-path " - ); - }; - EngineConfig::StaticCore { - engine: dynamo_llm::engines::make_engine_core(), - model: Box::new(local_model), - } - } - #[cfg(feature = "mistralrs")] - Output::MistralRs => EngineConfig::StaticFull { - engine: dynamo_engine_mistralrs::make_engine(&local_model).await?, - model: Box::new(local_model), - }, - Output::SgLang => { - if !local_model.path().is_dir() { - // TODO Does sglang support GGUF? Can we make it work? - anyhow::bail!("`--model-path should point at a HuggingFace repo checkout"); - } - - // If `in=dyn` we want the sglang subprocess to listen on that endpoint. - // If not, then the endpoint isn't exposed so we invent an internal one. - let endpoint = match &in_opt { - Input::Endpoint(path) => path.parse()?, - _ => internal_endpoint("sglang"), - }; - - let multi_node_conf = dynamo_llm::engines::MultiNodeConfig { - num_nodes: flags.num_nodes, - node_rank: flags.node_rank, - leader_addr: flags.leader_addr.clone().unwrap_or_default(), - }; - let (py_script, child) = match subprocess::start( - subprocess::sglang::PY, - &local_model, - &endpoint, - flags.clone(), - if flags.num_nodes <= 1 { - None - } else { - Some(multi_node_conf) - }, - ) - .await - { - Ok(x) => x, - Err(err) => { - anyhow::bail!("Failed starting sglang sub-process: {err}"); - } - }; - let cancel_token = cancel_token.clone(); - - // Sub-process cleanup - extra = Some(Box::pin(async move { - stopper(cancel_token, child, py_script).await; - })); - EngineConfig::Dynamic - } - Output::Vllm => { - if flags.base_gpu_id != 0 { - anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead."); - } + // Now that we know the output we're targeting, check if we expect it to work + flags.validate(&local_model, &out_opt)?; - // If `in=dyn` we want the vllm subprocess to listen on that endpoint. - // If not, then the endpoint isn't exposed so we invent an internal one. - let endpoint = match &in_opt { - Input::Endpoint(path) => path.parse()?, - _ => internal_endpoint("vllm"), - }; + // Make an engine from the local_model, flags and output. + let (engine_config, extra) = + engine_for(runtime.primary_token(), out_opt, flags.clone(), local_model).await?; - let (py_script, child) = match subprocess::start( - subprocess::vllm::PY, - &local_model, - &endpoint, - flags.clone(), - None, // multi-node config. vllm uses `ray`, see guide - ) - .await - { - Ok(x) => x, - Err(err) => { - anyhow::bail!("Failed starting vllm sub-process: {err}"); - } - }; - let cancel_token = cancel_token.clone(); + // + // Run in from an input + // - // Sub-process cleanup - extra = Some(Box::pin(async move { - stopper(cancel_token, child, py_script).await; - })); - EngineConfig::Dynamic - } - Output::Trtllm => { - if flags.base_gpu_id != 0 { - anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead."); - } + dynamo_llm::entrypoint::input::run_input(in_opt, runtime, engine_config).await?; - // If `in=dyn` we want the trtllm subprocess to listen on that endpoint. - // If not, then the endpoint isn't exposed so we invent an internal one. - let endpoint = match &in_opt { - Input::Endpoint(path) => path.parse()?, - _ => internal_endpoint("trtllm"), - }; + // Allow engines to ask main thread to wait on an extra future. 
+ // We use this to stop the vllm and sglang sub-process + if let Some(extra) = extra { + extra.await; + } - let (py_script, child) = match subprocess::start( - subprocess::trtllm::PY, - &local_model, - &endpoint, - flags.clone(), - None, // multi-node config. trtlllm uses `mpi`, see guide - ) - .await - { - Ok(x) => x, - Err(err) => { - anyhow::bail!("Failed starting trtllm sub-process: {err}"); - } - }; - let cancel_token = cancel_token.clone(); + Ok(()) +} - // Sub-process cleanup - extra = Some(Box::pin(async move { - stopper(cancel_token, child, py_script).await; - })); - EngineConfig::Dynamic - } +type ExtraFuture = Pin<Box<dyn Future<Output = ()> + Send>>; +/// Create the engine matching `out_opt` +/// Note validation happens in Flags::validate. Here we assume everything will work. +async fn engine_for( + cancel_token: CancellationToken, + out_opt: Output, + flags: Flags, + local_model: LocalModel, +) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> { + match out_opt { + Output::Dynamic => Ok((EngineConfig::Dynamic(Box::new(local_model)), None)), + Output::EchoFull => Ok(( + EngineConfig::StaticFull { + model: Box::new(local_model), + engine: dynamo_llm::engines::make_engine_full(), + }, + None, + )), + Output::EchoCore => Ok(( + EngineConfig::StaticCore { + engine: dynamo_llm::engines::make_engine_core(), + model: Box::new(local_model), + }, + None, + )), + #[cfg(feature = "mistralrs")] + Output::MistralRs => Ok(( + EngineConfig::StaticFull { + engine: dynamo_engine_mistralrs::make_engine(&local_model).await?, + model: Box::new(local_model), + }, + None, + )), #[cfg(feature = "llamacpp")] - Output::LlamaCpp => { - if !local_model.path().is_file() { - anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors."); - } - let engine = - dynamo_engine_llamacpp::make_engine(cancel_token.clone(), &local_model).await?; + Output::LlamaCpp => Ok(( EngineConfig::StaticCore { - engine, + engine: dynamo_engine_llamacpp::make_engine(cancel_token, &local_model).await?, model: Box::new(local_model), - } - } - }; - - match in_opt { - Input::Http => { - crate::input::http::run(runtime.clone(), flags, engine_config, template).await?; - } - Input::Text => { - crate::input::text::run(runtime.clone(), flags, None, engine_config, template).await?; - } - Input::Stdin => { - let mut prompt = String::new(); - std::io::stdin().read_to_string(&mut prompt).unwrap(); - crate::input::text::run( - runtime.clone(), + }, + None, + )), + // For multi-node config. vllm uses `ray`, see guide + Output::Vllm => shell(subprocess::vllm::PY, cancel_token, local_model, flags, None).await, + // For multi-node config. 
trtlllm uses `mpi`, see guide + Output::Trtllm => { + shell( + subprocess::trtllm::PY, + cancel_token, + local_model, flags, - Some(prompt), - engine_config, - template, + None, ) - .await?; - } - Input::Batch(path) => { - crate::input::batch::run(runtime.clone(), flags, card, path, engine_config, template) - .await?; + .await } - Input::Endpoint(path) => { - let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; - crate::input::endpoint::run(distributed_runtime, path, engine_config).await?; + Output::SgLang => { + let multi_node_config = if flags.num_nodes > 1 { + Some(dynamo_llm::engines::MultiNodeConfig { + num_nodes: flags.num_nodes, + node_rank: flags.node_rank, + leader_addr: flags.leader_addr.clone().unwrap_or_default(), + }) + } else { + None + }; + shell( + subprocess::sglang::PY, + cancel_token, + local_model, + flags, + multi_node_config, + ) + .await } } +} - // Allow engines to ask main thread to wait on an extra future. - // We use this to stop the vllm and sglang sub-process - if let Some(extra) = extra { - extra.await; - } +async fn shell( + py_script: &'static str, + cancel_token: CancellationToken, + local_model: LocalModel, + flags: Flags, + multi_node_config: Option, +) -> anyhow::Result<(EngineConfig, Option)> { + let (py_script, child) = + match subprocess::start(py_script, &local_model, flags.clone(), multi_node_config).await { + Ok(x) => x, + Err(err) => { + anyhow::bail!("Failed starting engine sub-process: {err}"); + } + }; - Ok(()) + // Sub-process cleanup + let extra: ExtraFuture = Box::pin(async move { + stopper(cancel_token, child, py_script).await; + }); + Ok((EngineConfig::Dynamic(Box::new(local_model)), Some(extra))) } /// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible. @@ -341,21 +197,21 @@ async fn stopper( tokio::select! { exit = child.wait() => { - tracing::trace!("vllm sub-process graceful exit"); + tracing::trace!("engine sub-process graceful exit"); match exit { Ok(exit_status) if exit_status.success() => {} Ok(exit_status) => { // This is nearly always 15 (SIGTERM) - tracing::trace!("vllm sub-process non-0 exit: {exit_status}"); + tracing::trace!("engine sub-process non-0 exit: {exit_status}"); } Err(err) => { - tracing::warn!("vllm sub-process error getting exit status: {err}"); + tracing::warn!("engine sub-process error getting exit status: {err}"); } } } _ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => { // It didn't stop in time, kill it - child.kill().await.expect("Failed killing vllm subprocess"); + child.kill().await.expect("Failed killing engine subprocess"); let _ = child.wait().await; } } @@ -400,6 +256,19 @@ fn print_cuda(output: &Output) { #[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))] fn print_cuda(_output: &Output) {} +fn default_engine_for(local_model: &LocalModel) -> Output { + let default_engine = if local_model.card().is_gguf() { + gguf_default() + } else { + safetensors_default() + }; + tracing::info!( + "Using default engine: {default_engine}. 
Use out= to specify one of {}", + Output::available_engines().join(", ") + ); + default_engine +} + fn gguf_default() -> Output { #[cfg(feature = "llamacpp")] { @@ -428,13 +297,3 @@ fn safetensors_default() -> Output { Output::EchoFull } } - -/// A random endpoint to use for internal communication -/// We can't hard code because we may be running several on the same machine (GPUs 0-3 and 4-7) -fn internal_endpoint(engine: &str) -> EndpointId { - EndpointId { - namespace: Slug::slugify(&uuid::Uuid::new_v4().to_string()).to_string(), - component: engine.to_string(), - name: "generate".to_string(), - } -} diff --git a/launch/dynamo-run/src/main.rs b/launch/dynamo-run/src/main.rs index 46484bdb11..50670517fb 100644 --- a/launch/dynamo-run/src/main.rs +++ b/launch/dynamo-run/src/main.rs @@ -17,7 +17,8 @@ use std::env; use clap::Parser; -use dynamo_run::{Input, Output}; +use dynamo_llm::entrypoint::input::Input; +use dynamo_run::Output; use dynamo_runtime::logging; const HELP: &str = r#" @@ -127,5 +128,17 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> { .chain(env::args().skip(non_flag_params)), )?; + if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) { + anyhow::bail!("Cannot use endpoint for both in and out"); + } + dynamo_run::run(runtime, in_opt, out_opt, flags).await } + +fn is_in_dynamic(in_opt: &Input) -> bool { + matches!(in_opt, Input::Endpoint(_)) +} + +fn is_out_dynamic(out_opt: &Option) -> bool { + matches!(out_opt, Some(Output::Dynamic)) +} diff --git a/launch/dynamo-run/src/opt.rs b/launch/dynamo-run/src/opt.rs index 25ab953eb8..9b62f9e466 100644 --- a/launch/dynamo-run/src/opt.rs +++ b/launch/dynamo-run/src/opt.rs @@ -1,84 +1,8 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{fmt, io::IsTerminal as _, path::PathBuf}; use dynamo_runtime::protocols::ENDPOINT_SCHEME; - -const BATCH_PREFIX: &str = "batch:"; - -#[derive(PartialEq)] -pub enum Input { - /// Run an OpenAI compatible HTTP server - Http, - - /// Single prompt on stdin - Stdin, - - /// Interactive chat - Text, - - /// Pull requests from a namespace/component/endpoint path. - Endpoint(String), - - /// Batch mode. Run all the prompts, write the outputs, exit. 
- Batch(PathBuf), -} - -impl TryFrom<&str> for Input { - type Error = anyhow::Error; - - fn try_from(s: &str) -> anyhow::Result { - match s { - "http" => Ok(Input::Http), - "text" => Ok(Input::Text), - "stdin" => Ok(Input::Stdin), - endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => { - Ok(Input::Endpoint(endpoint_path.to_string())) - } - batch_patch if batch_patch.starts_with(BATCH_PREFIX) => { - let path = batch_patch.strip_prefix(BATCH_PREFIX).unwrap(); - Ok(Input::Batch(PathBuf::from(path))) - } - e => Err(anyhow::anyhow!("Invalid in= option '{e}'")), - } - } -} - -impl fmt::Display for Input { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let s = match self { - Input::Http => "http", - Input::Text => "text", - Input::Stdin => "stdin", - Input::Endpoint(path) => path, - Input::Batch(path) => &path.display().to_string(), - }; - write!(f, "{s}") - } -} - -impl Default for Input { - fn default() -> Self { - if std::io::stdin().is_terminal() { - Input::Text - } else { - Input::Stdin - } - } -} +use std::fmt; pub enum Output { /// Accept un-preprocessed requests, echo the prompt back as the response diff --git a/launch/dynamo-run/src/subprocess.rs b/launch/dynamo-run/src/subprocess.rs index 4b88c664dd..73ed24f7bd 100644 --- a/launch/dynamo-run/src/subprocess.rs +++ b/launch/dynamo-run/src/subprocess.rs @@ -13,7 +13,6 @@ use tokio::io::AsyncBufReadExt; use crate::flags::RouterMode; use dynamo_llm::engines::MultiNodeConfig; use dynamo_llm::local_model::LocalModel; -use dynamo_runtime::protocols::Endpoint as EndpointId; pub mod sglang; pub mod trtllm; @@ -24,8 +23,6 @@ pub async fn start( py_script: &'static str, // Model info local_model: &LocalModel, - // Endpoint to connect the subprocess over etcd/nats - endpoint: &EndpointId, // Command line flags for user overrides flags: super::Flags, // sglang multi-node config. 
vllm uses `ray` externally @@ -40,7 +37,7 @@ pub async fn start( let mut args = vec![ script_path.to_string_lossy().to_string(), "--endpoint".to_string(), - endpoint.as_url(), + local_model.endpoint_id().as_url(), "--model-path".to_string(), local_model.path().to_string_lossy().to_string(), "--model-name".to_string(), diff --git a/launch/llmctl/src/main.rs b/launch/llmctl/src/main.rs index f004c001ca..f8b96d8427 100644 --- a/launch/llmctl/src/main.rs +++ b/launch/llmctl/src/main.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use clap::{Parser, Subcommand}; use dynamo_llm::discovery::{ModelManager, ModelWatcher}; -use dynamo_llm::local_model::{LocalModel, ModelNetworkName}; +use dynamo_llm::local_model::{LocalModelBuilder, ModelNetworkName}; use dynamo_llm::model_type::ModelType; use dynamo_runtime::component::Endpoint; use dynamo_runtime::pipeline::RouterMode; @@ -227,7 +227,10 @@ async fn add_model( let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?; - let mut model = LocalModel::with_name_only(&model_name); + let mut model = LocalModelBuilder::default() + .model_name(Some(model_name)) + .build() + .await?; model.attach(&endpoint, model_type).await?; Ok(()) diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 1c50f4aa8e..62c22d6325 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -96,7 +96,7 @@ pub unsafe extern "C" fn dynamo_llm_init( match result { Ok(_) => match KV_PUB.get_or_try_init(move || { - dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize) + dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size) }) { Ok(_) => DynamoLlmResult::OK, Err(e) => { @@ -139,7 +139,7 @@ fn dynamo_create_kv_publisher( namespace: String, component: String, worker_id: i64, - kv_block_size: usize, + kv_block_size: u32, ) -> Result { tracing::info!("Creating KV Publisher for model: {}", component); match DRT @@ -158,7 +158,7 @@ fn kv_event_create_stored_block_from_parts( block_hash: u64, token_ids: *const u32, num_tokens: usize, - kv_block_size: usize, + kv_block_size: u32, _lora_id: u64, ) -> KvCacheStoredBlockData { let tokens_hash = compute_block_hash_for_seq( @@ -174,7 +174,7 @@ static WARN_COUNT: AtomicU32 = AtomicU32::new(0); fn kv_event_create_stored_from_parts( kv_params: DynamoKvStoredEventParams, - kv_block_size: usize, + kv_block_size: u32, ) -> KvCacheEvent { let mut blocks: Vec = Vec::new(); @@ -188,7 +188,7 @@ fn kv_event_create_stored_from_parts( .offset(block_idx.try_into().unwrap()) }; - if num_toks != kv_block_size { + if num_toks != (kv_block_size as usize) { if WARN_COUNT .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| { if c < 3 { diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index 4024db5235..f08ff7f2e2 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -1011,6 +1011,18 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror 1.0.69", +] + [[package]] name = "digest" version = "0.10.7" @@ -1123,6 +1135,7 @@ dependencies = [ "cudarc", "derive-getters", "derive_builder", + "dialoguer", "dynamo-runtime", "either", "erased-serde", @@ -1131,6 +1144,7 @@ dependencies = [ "galil-seiferas", "ggus", "hf-hub", + "humantime", "itertools 0.14.0", "memmap2", 
"minijinja", @@ -4224,6 +4238,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "shlex" version = "1.3.0" diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index fa03c85534..f759157b5b 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -9,6 +9,7 @@ use pyo3::types::{PyDict, PyList, PyString}; use pyo3::IntoPyObjectExt; use pyo3::{exceptions::PyException, prelude::*}; use rs::pipeline::network::Ingress; +use std::path::PathBuf; use std::{fmt::Display, sync::Arc}; use tokio::sync::Mutex; @@ -104,8 +105,8 @@ fn register_llm<'p>( endpoint: Endpoint, model_path: &str, model_name: Option<&str>, - context_length: Option, - kv_cache_block_size: Option, + context_length: Option, + kv_cache_block_size: Option, ) -> PyResult> { let model_type_obj = match model_type { ModelType::Chat => llm_rs::model_type::ModelType::Chat, @@ -117,18 +118,14 @@ fn register_llm<'p>( let inner_path = model_path.to_string(); let model_name = model_name.map(|n| n.to_string()); pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut builder = dynamo_llm::local_model::LocalModelBuilder::default(); + builder + .model_path(Some(PathBuf::from(inner_path))) + .model_name(model_name) + .context_length(context_length) + .kv_cache_block_size(kv_cache_block_size); // Download from HF, load the ModelDeploymentCard - let mut local_model = - llm_rs::local_model::LocalModel::prepare(&inner_path, None, model_name) - .await - .map_err(to_pyerr)?; - if let Some(context_length) = context_length { - local_model.set_context_length(context_length); - } - if let Some(kv_cache_block_size) = kv_cache_block_size { - local_model.set_kv_cache_block_size(kv_cache_block_size); - } - + let mut local_model = builder.build().await.map_err(to_pyerr)?; // Advertise ourself on etcd so ingress can find us local_model .attach(&endpoint.inner, model_type_obj) diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 53da276b26..3a9435b89a 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -40,10 +40,13 @@ impl KvRouter { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let inner = - llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None) - .await - .map_err(to_pyerr)?; + let inner = llm_rs::kv_router::KvRouter::new( + component.inner.clone(), + kv_block_size as u32, + None, + ) + .await + .map_err(to_pyerr)?; Ok(Self { inner: Arc::new(inner), }) @@ -73,7 +76,7 @@ pub fn compute_block_hash_for_seq_py(tokens: Vec, kv_block_size: usize) -> return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); } - let hashes = compute_block_hash_for_seq(&tokens, kv_block_size); + let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32); Ok(hashes.into_iter().map(|h| h.0).collect()) } @@ -191,7 +194,7 @@ impl ZmqKvEventPublisher { let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( component.inner, config.worker_id, - config.kv_block_size, + config.kv_block_size as u32, Some(KvEventSourceConfig::Zmq { endpoint: config.zmq_endpoint, topic: config.zmq_topic, @@ -232,7 +235,7 @@ impl ZmqKvEventListener { zmq_topic, tx, shutdown_token.clone(), - kv_block_size, + kv_block_size as u32, )); Ok(Self { @@ -293,7 +296,7 @@ 
impl KvEventPublisher { let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( component.inner, worker_id, - kv_block_size, + kv_block_size as u32, None, ) .map_err(to_pyerr)?; @@ -322,7 +325,7 @@ impl KvEventPublisher { data: KvCacheEventData::Stored(KvCacheStoreData { parent_hash: parent_hash.map(ExternalSequenceBlockHash::from), blocks: create_stored_blocks( - self.kv_block_size, + self.kv_block_size as u32, &token_ids, &num_block_tokens, &block_hashes, @@ -446,7 +449,7 @@ impl KvIndexer { let inner: Arc = llm_rs::kv_router::indexer::KvIndexer::new( component.inner.drt().runtime().child_token(), - kv_block_size, + kv_block_size as u32, ) .into(); // [gluo TODO] try subscribe_with_type::, @@ -478,7 +481,7 @@ impl KvIndexer { } fn block_size(&self) -> usize { - self.inner.block_size() + self.inner.block_size() as usize } fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec) -> PyResult> { diff --git a/lib/engines/llamacpp/src/lib.rs b/lib/engines/llamacpp/src/lib.rs index 768421e610..7e4b73d841 100644 --- a/lib/engines/llamacpp/src/lib.rs +++ b/lib/engines/llamacpp/src/lib.rs @@ -78,7 +78,7 @@ impl LlamacppEngine { let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS); let llama_ctx_params = if model_config.card().context_length > 0 { - let n_ctx = NonZeroU32::new(model_config.card().context_length as u32); + let n_ctx = NonZeroU32::new(model_config.card().context_length); LlamaContextParams::default().with_n_ctx(n_ctx) } else { // Context length defaults to 512 currently diff --git a/lib/engines/mistralrs/src/lib.rs b/lib/engines/mistralrs/src/lib.rs index 44c0d668db..80a3fbd1a8 100644 --- a/lib/engines/mistralrs/src/lib.rs +++ b/lib/engines/mistralrs/src/lib.rs @@ -128,7 +128,7 @@ impl MistralRsEngine { .build(None)? }; - let mut max_seq_len = model.card().context_length; + let mut max_seq_len = model.card().context_length as usize; if max_seq_len == 0 { tracing::info!("context_length is 0. 
Probably error reading from model."); max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN; diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 05f5588ee6..cb696faff3 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -53,6 +53,7 @@ either = { workspace = true } etcd-client = { workspace = true } futures = { workspace = true } hf-hub = { workspace = true } +humantime = { workspace = true } # input/batch rand = { workspace = true } oneshot = { workspace = true } prometheus = { workspace = true } @@ -80,6 +81,9 @@ offset-allocator = "0.2" regex = "1" rayon = "1" +# input/text +dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] } + # block_manager nixl-sys = {git="https://github.com/ai-dynamo/nixl", rev = "a7c654d46a14cd5ce635cc8c02433d71df93dedf", optional = true } cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true } diff --git a/lib/llm/src/block_manager/block.rs b/lib/llm/src/block_manager/block.rs index db8ca4c4f9..5e8eec50f3 100644 --- a/lib/llm/src/block_manager/block.rs +++ b/lib/llm/src/block_manager/block.rs @@ -1605,7 +1605,7 @@ mod tests { use dynamo_runtime::logging::init as init_logging; use nixl_sys::Agent as NixlAgent; - const BLOCK_SIZE: usize = 4; + const BLOCK_SIZE: u32 = 4; const SALT_HASH: SaltHash = 12345; // Helper to create a default reset block @@ -1666,7 +1666,7 @@ mod tests { // Extend to fill capacity assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); // 1, 2, 3, 4 - assert_eq!(block.len(), BLOCK_SIZE); + assert_eq!(block.len(), BLOCK_SIZE as usize); // Append when full (should fail) assert!(block.add_token(5).is_err(), "Append on full Partial block"); @@ -1690,7 +1690,7 @@ mod tests { // Fill block again for commit assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok()); - assert_eq!(block.len(), BLOCK_SIZE); + assert_eq!(block.len(), BLOCK_SIZE as usize); // --- Partial -> Complete (via commit) --- // assert!(block.commit().is_ok()); diff --git a/lib/llm/src/block_manager/block/state.rs b/lib/llm/src/block_manager/block/state.rs index 5eb6ff5bff..b5c41a87c8 100644 --- a/lib/llm/src/block_manager/block/state.rs +++ b/lib/llm/src/block_manager/block/state.rs @@ -43,7 +43,7 @@ impl BlockState { return Err(BlockStateInvalid("Block is not reset".to_string())); } - let block = PartialTokenBlock::create_sequence_root(page_size, salt_hash); + let block = PartialTokenBlock::create_sequence_root(page_size as u32, salt_hash); *self = BlockState::Partial(PartialState::new(block)); Ok(()) } diff --git a/lib/llm/src/block_manager/pool/inactive.rs b/lib/llm/src/block_manager/pool/inactive.rs index fe8b84fb28..9b695fa35a 100644 --- a/lib/llm/src/block_manager/pool/inactive.rs +++ b/lib/llm/src/block_manager/pool/inactive.rs @@ -648,7 +648,7 @@ pub(crate) mod tests { /// Each block is initialized to the Complete state and then Registered. 
pub fn create_blocks( tokens: Tokens, - block_size: usize, + block_size: u32, async_runtime: Handle, ) -> Vec> { let (token_blocks, _partial_token_block) = @@ -691,7 +691,7 @@ pub(crate) mod tests { pub fn acquire_blocks( tokens: Tokens, - block_size: usize, + block_size: u32, pool: &mut InactiveBlockPool, async_runtime: Handle, ) -> (Vec>, usize) { @@ -749,7 +749,7 @@ pub(crate) mod tests { let async_runtime = tokio::runtime::Runtime::new().unwrap(); - const PAGE_SIZE: usize = 2; + const PAGE_SIZE: u32 = 2; let mut pool = create_block_pool(10); assert_eq!(pool.total_blocks(), 10); diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index 0341026d63..550c43a90f 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ b/lib/llm/src/discovery/model_manager.rs @@ -180,7 +180,7 @@ impl ModelManager { &self, model_name: &str, component: &Component, - kv_cache_block_size: usize, + kv_cache_block_size: u32, kv_router_config: Option<KvRouterConfig>, ) -> anyhow::Result> { if let Some(kv_chooser) = self.get_kv_chooser(model_name) { @@ -209,7 +209,7 @@ impl ModelManager { &self, model_name: &str, component: &Component, - kv_cache_block_size: usize, + kv_cache_block_size: u32, kv_router_config: Option<KvRouterConfig>, ) -> anyhow::Result> { let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); diff --git a/lib/llm/src/entrypoint.rs b/lib/llm/src/entrypoint.rs new file mode 100644 index 0000000000..6881b10ffc --- /dev/null +++ b/lib/llm/src/entrypoint.rs @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! The entrypoint module provides tools to build a Dynamo runner. +//! - Create an EngineConfig of the engine (potentially auto-discovered) to execute +//! - Connect it to an Input + +pub mod input; + +use std::sync::Arc; + +use dynamo_runtime::pipeline::RouterMode; + +use crate::{ + backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig, + local_model::LocalModel, +}; + +#[derive(Debug, Clone, Default)] +pub struct RouterConfig { + pub router_mode: RouterMode, + pub kv_router_config: KvRouterConfig, +} + +impl RouterConfig { + pub fn new(router_mode: RouterMode, kv_router_config: KvRouterConfig) -> Self { + Self { + router_mode, + kv_router_config, + } + } +} + +pub enum EngineConfig { + /// Remote networked engines + Dynamic(Box<LocalModel>), + + /// A Full service engine does its own tokenization and prompt formatting. + StaticFull { + engine: Arc<dyn StreamingEngine>, + model: Box<LocalModel>, + }, + + /// A core engine expects to be wrapped with pre/post processors that handle tokenization. + StaticCore { + engine: ExecutionContext, + model: Box<LocalModel>, + }, +} + +impl EngineConfig { + fn local_model(&self) -> &LocalModel { + use EngineConfig::*; + match self { + Dynamic(lm) => lm, + StaticFull { model, .. } => model, + StaticCore { model, .. } => model, + } + } +} diff --git a/lib/llm/src/entrypoint/input.rs b/lib/llm/src/entrypoint/input.rs new file mode 100644 index 0000000000..19bcf59dbc --- /dev/null +++ b/lib/llm/src/entrypoint/input.rs @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! This module contains tools to gather a prompt from a user, forward it to an engine and return +//! the response. +//! See the Input enum for the inputs available. Input::Http (OpenAI compatible HTTP server) +//! and Input::Text (interactive chat) are good places to start. +//! 
The main entry point is `run_input`. + +use std::{ + fmt, + io::{IsTerminal as _, Read as _}, + path::PathBuf, +}; + +pub mod batch; +mod common; +pub mod endpoint; +pub mod http; +pub mod text; + +use dynamo_runtime::{protocols::ENDPOINT_SCHEME, DistributedRuntime}; + +const BATCH_PREFIX: &str = "batch:"; + +/// The various ways of connecting prompts to an engine +#[derive(PartialEq)] +pub enum Input { + /// Run an OpenAI compatible HTTP server + Http, + + /// Single prompt on stdin + Stdin, + + /// Interactive chat + Text, + + /// Pull requests from a namespace/component/endpoint path. + Endpoint(String), + + /// Batch mode. Run all the prompts, write the outputs, exit. + Batch(PathBuf), +} + +impl TryFrom<&str> for Input { + type Error = anyhow::Error; + + fn try_from(s: &str) -> anyhow::Result<Self> { + match s { + "http" => Ok(Input::Http), + "text" => Ok(Input::Text), + "stdin" => Ok(Input::Stdin), + endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => { + Ok(Input::Endpoint(endpoint_path.to_string())) + } + batch_path if batch_path.starts_with(BATCH_PREFIX) => { + let path = batch_path.strip_prefix(BATCH_PREFIX).unwrap(); + Ok(Input::Batch(PathBuf::from(path))) + } + e => Err(anyhow::anyhow!("Invalid in= option '{e}'")), + } + } +} + +impl fmt::Display for Input { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let s = match self { + Input::Http => "http", + Input::Text => "text", + Input::Stdin => "stdin", + Input::Endpoint(path) => path, + Input::Batch(path) => &path.display().to_string(), + }; + write!(f, "{s}") + } +} + +impl Default for Input { + fn default() -> Self { + if std::io::stdin().is_terminal() { + Input::Text + } else { + Input::Stdin + } + } +} + +/// Run the given engine (EngineConfig) connected to an input. +/// Does not return until the input exits. +pub async fn run_input( + in_opt: Input, + runtime: dynamo_runtime::Runtime, + engine_config: super::EngineConfig, +) -> anyhow::Result<()> { + match in_opt { + Input::Http => { + http::run(runtime.clone(), engine_config).await?; + } + Input::Text => { + text::run(runtime.clone(), None, engine_config).await?; + } + Input::Stdin => { + let mut prompt = String::new(); + std::io::stdin().read_to_string(&mut prompt).unwrap(); + text::run(runtime.clone(), Some(prompt), engine_config).await?; + } + Input::Batch(path) => { + batch::run(runtime.clone(), path, engine_config).await?; + } + Input::Endpoint(path) => { + let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; + endpoint::run(distributed_runtime, path, engine_config).await?; + } + } + Ok(()) +} diff --git a/launch/dynamo-run/src/input/batch.rs b/lib/llm/src/entrypoint/input/batch.rs similarity index 94% rename from launch/dynamo-run/src/input/batch.rs rename to lib/llm/src/entrypoint/input/batch.rs index c236aebacb..ede518f1ee 100644 --- a/launch/dynamo-run/src/input/batch.rs +++ b/lib/llm/src/entrypoint/input/batch.rs @@ -1,14 +1,13 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use anyhow::Context as _; -use async_openai::types::FinishReason; -use dynamo_llm::model_card::model::ModelDeploymentCard; -use dynamo_llm::preprocessor::OpenAIPreprocessor; -use dynamo_llm::request_template::RequestTemplate; -use dynamo_llm::types::openai::chat_completions::{ +use crate::preprocessor::OpenAIPreprocessor; +use crate::request_template::RequestTemplate; +use crate::types::openai::chat_completions::{ NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, }; +use anyhow::Context as _; +use async_openai::types::FinishReason; use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime}; use futures::StreamExt; use serde::{Deserialize, Serialize}; @@ -19,8 +18,8 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; -use crate::input::common; -use crate::{EngineConfig, Flags}; +use crate::entrypoint::input::common; +use crate::entrypoint::EngineConfig; /// Max tokens in each response. /// TODO: For batch mode this should be the full context size of the model @@ -53,11 +52,8 @@ struct Entry { pub async fn run( runtime: Runtime, - _flags: Flags, - card: ModelDeploymentCard, input_jsonl: PathBuf, engine_config: EngineConfig, - template: Option, ) -> anyhow::Result<()> { let cancel_token = runtime.primary_token(); // Check if the path exists and is a directory @@ -68,11 +64,10 @@ pub async fn run( ); } - let prepared_engine = common::prepare_engine(runtime, engine_config).await?; - let service_name_ref = Arc::new(prepared_engine.service_name); + let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?; - let pre_processor = if card.has_tokenizer() { - Some(OpenAIPreprocessor::new(card).await?) + let pre_processor = if prepared_engine.has_tokenizer() { + Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?) 
} else { None }; @@ -85,6 +80,7 @@ pub async fn run( tracing::error!(%err, "Failed writing output to {}", output_file.display()); } }); + let service_name_ref = Arc::new(prepared_engine.service_name); let tokens_in = Arc::new(AtomicU64::new(0)); let tokens_out = Arc::new(AtomicU64::new(0)); @@ -98,7 +94,7 @@ pub async fn run( tracing::info!("Timer start."); let start = Instant::now(); let mut lines = buffered_input.lines(); - let template: Option<Arc<RequestTemplate>> = template.map(Arc::new); + let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new); while let Ok(Some(line)) = lines.next_line().await { if cancel_token.is_cancelled() { break; } diff --git a/launch/dynamo-run/src/input/common.rs b/lib/llm/src/entrypoint/input/common.rs similarity index 87% rename from launch/dynamo-run/src/input/common.rs rename to lib/llm/src/entrypoint/input/common.rs index c9f17b346a..8f46de078b 100644 --- a/launch/dynamo-run/src/input/common.rs +++ b/lib/llm/src/entrypoint/input/common.rs @@ -3,13 +3,15 @@ use std::pin::Pin; -use dynamo_llm::{ +use crate::{ backend::{Backend, ExecutionContext}, discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH}, engines::StreamingEngineAdapter, + entrypoint::EngineConfig, model_card::ModelDeploymentCard, preprocessor::OpenAIPreprocessor, protocols::common::llm_backend::{BackendOutput, PreprocessedRequest}, + request_template::RequestTemplate, types::{ openai::chat_completions::{ NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, @@ -25,12 +27,22 @@ use dynamo_runtime::{ }; use std::sync::Arc; -use crate::EngineConfig; - pub struct PreparedEngine { pub service_name: String, pub engine: OpenAIChatCompletionsStreamingEngine, pub inspect_template: bool, + pub card: Option<ModelDeploymentCard>, + pub request_template: Option<RequestTemplate>, +} + +impl PreparedEngine { + pub fn has_tokenizer(&self) -> bool { + if let Some(card) = self.card.as_ref() { + card.has_tokenizer() + } else { + false + } + } } /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. 
@@ -39,7 +51,7 @@ pub async fn prepare_engine( engine_config: EngineConfig, ) -> anyhow::Result<PreparedEngine> { match engine_config { - EngineConfig::Dynamic => { + EngineConfig::Dynamic(local_model) => { let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let Some(etcd_client) = distributed_runtime.etcd_client() else { @@ -71,6 +83,8 @@ pub async fn prepare_engine( service_name: model_service_name, engine, inspect_template: false, + card: None, + request_template: local_model.request_template(), }) } EngineConfig::StaticFull { engine, model } => { @@ -81,6 +95,8 @@ pub async fn prepare_engine( service_name, engine, inspect_template: false, + request_template: model.request_template(), + card: Some(model.into_card()), }) } EngineConfig::StaticCore { @@ -99,6 +115,8 @@ pub async fn prepare_engine( service_name, engine: pipeline, inspect_template: true, + request_template: model.request_template(), + card: Some(model.into_card()), }) } } @@ -137,21 +155,21 @@ where #[cfg(test)] mod tests { use super::*; - use dynamo_llm::types::openai::{ + use crate::types::openai::{ chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, }; const HF_PATH: &str = concat!( env!("CARGO_MANIFEST_DIR"), - "/../../lib/llm/tests/data/sample-models/mock-llama-3.1-8b-instruct" + "/tests/data/sample-models/mock-llama-3.1-8b-instruct" ); #[tokio::test] async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { // Create test model card let card = ModelDeploymentCard::load(HF_PATH).await?; - let engine = dynamo_llm::engines::make_engine_core(); + let engine = crate::engines::make_engine_core(); // Build pipeline for chat completions let pipeline = build_pipeline::< @@ -170,7 +188,7 @@ mod tests { async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { // Create test model card let card = ModelDeploymentCard::load(HF_PATH).await?; - let engine = dynamo_llm::engines::make_engine_core(); + let engine = crate::engines::make_engine_core(); // Build pipeline for completions let pipeline = diff --git a/launch/dynamo-run/src/input/endpoint.rs b/lib/llm/src/entrypoint/input/endpoint.rs similarity index 86% rename from launch/dynamo-run/src/input/endpoint.rs rename to lib/llm/src/entrypoint/input/endpoint.rs index 582f313ae5..e787f2a13c 100644 --- a/launch/dynamo-run/src/input/endpoint.rs +++ b/lib/llm/src/entrypoint/input/endpoint.rs @@ -1,21 +1,9 @@ // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
use std::{future::Future, pin::Pin, sync::Arc}; -use dynamo_llm::{ +use crate::{ backend::Backend, engines::StreamingEngineAdapter, model_type::ModelType, @@ -33,7 +21,7 @@ use dynamo_runtime::pipeline::{ }; use dynamo_runtime::{protocols::Endpoint as EndpointId, DistributedRuntime}; -use crate::EngineConfig; +use crate::entrypoint::EngineConfig; pub async fn run( distributed_runtime: DistributedRuntime, @@ -91,7 +79,7 @@ pub async fn run( (Box::pin(fut), Some(model.card().clone())) } - EngineConfig::Dynamic => { + EngineConfig::Dynamic(_) => { // We can only get here for in=dyn out=vllm|sglang`, because vllm and sglang are a // subprocess that we talk to like a remote endpoint. // That means the vllm/sglang subprocess is doing all the work, we are idle. diff --git a/launch/dynamo-run/src/input/http.rs b/lib/llm/src/entrypoint/input/http.rs similarity index 87% rename from launch/dynamo-run/src/input/http.rs rename to lib/llm/src/entrypoint/input/http.rs index aaab80b467..fa229fff8f 100644 --- a/launch/dynamo-run/src/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -3,14 +3,12 @@ use std::sync::Arc; -use crate::input::common; -use crate::{EngineConfig, Flags}; -use dynamo_llm::kv_router::KvRouterConfig; -use dynamo_llm::{ +use crate::{ discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH}, engines::StreamingEngineAdapter, + entrypoint::{input::common, EngineConfig}, http::service::service_v2, - request_template::RequestTemplate, + kv_router::KvRouterConfig, types::{ openai::chat_completions::{ NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, @@ -23,32 +21,28 @@ use dynamo_runtime::transports::etcd; use dynamo_runtime::{DistributedRuntime, Runtime}; /// Build and run an HTTP service -pub async fn run( - runtime: Runtime, - flags: Flags, - engine_config: EngineConfig, - template: Option, -) -> anyhow::Result<()> { +pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { let http_service = service_v2::HttpService::builder() - .port(flags.http_port) + .port(engine_config.local_model().http_port()) .enable_chat_endpoints(true) .enable_cmpl_endpoints(true) .enable_embeddings_endpoints(true) - .with_request_template(template) + .with_request_template(engine_config.local_model().request_template()) .build()?; match engine_config { - EngineConfig::Dynamic => { + EngineConfig::Dynamic(_) => { let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; match distributed_runtime.etcd_client() { Some(etcd_client) => { + let router_config = engine_config.local_model().router_config(); // Listen for models registering themselves in etcd, add them to HTTP service run_watcher( distributed_runtime, http_service.state().manager_clone(), etcd_client.clone(), MODEL_ROOT_PATH, - flags.router_mode.into(), - Some(flags.kv_router_config()), + router_config.router_mode, + Some(router_config.kv_router_config.clone()), ) .await?; } diff --git a/launch/dynamo-run/src/input/text.rs b/lib/llm/src/entrypoint/input/text.rs similarity index 95% rename from launch/dynamo-run/src/input/text.rs rename to lib/llm/src/entrypoint/input/text.rs index b09d041cd7..fb6d2f0116 100644 --- a/launch/dynamo-run/src/input/text.rs +++ b/lib/llm/src/entrypoint/input/text.rs @@ -1,16 +1,17 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use dynamo_llm::protocols::openai::nvext::NvExt; -use dynamo_llm::types::openai::chat_completions::{ +use crate::protocols::openai::nvext::NvExt; +use crate::request_template::RequestTemplate; +use crate::types::openai::chat_completions::{ NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, }; use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime}; use futures::StreamExt; use std::io::{ErrorKind, Write}; -use crate::input::common; -use crate::{EngineConfig, Flags, RequestTemplate}; +use crate::entrypoint::input::common; +use crate::entrypoint::EngineConfig; /// Max response tokens for each single query. Must be less than model context size. /// TODO: Cmd line flag to overwrite this @@ -18,20 +19,19 @@ const MAX_TOKENS: u32 = 8192; pub async fn run( runtime: Runtime, - _flags: Flags, single_prompt: Option<String>, engine_config: EngineConfig, - template: Option<RequestTemplate>, ) -> anyhow::Result<()> { let cancel_token = runtime.primary_token(); let prepared_engine = common::prepare_engine(runtime, engine_config).await?; + // TODO: Pass prepared_engine directly main_loop( cancel_token, &prepared_engine.service_name, prepared_engine.engine, single_prompt, prepared_engine.inspect_template, - template, + prepared_engine.request_template, ) .await } diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 0c732d869c..e2b9b98ba7 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -50,7 +50,7 @@ pub trait WorkerSelector { &self, workers: &ProcessedEndpoints, request: &SchedulingRequest, - block_size: usize, + block_size: u32, ) -> Result; } @@ -104,13 +104,13 @@ impl KvRouterConfig { pub struct KvRouter { indexer: KvIndexer, scheduler: KvScheduler, - block_size: usize, + block_size: u32, } impl KvRouter { pub async fn new( component: Component, - block_size: usize, + block_size: u32, selector: Option>, ) -> Result { let cancellation_token = component @@ -196,7 +196,7 @@ impl KvRouter { } /// Get the block size this router was configured with - pub fn block_size(&self) -> usize { + pub fn block_size(&self) -> u32 { self.block_size } } diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index 3080c2de08..382c8f08f6 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -119,9 +119,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { /// ### Returns /// /// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens. -pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> { +pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> { tokens - .chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements + .chunks_exact(kv_block_size as usize) // Split into chunks of kv_block_size elements .map(|chunk| { let bytes: Vec<u8> = chunk .iter() @@ -527,7 +527,7 @@ pub struct KvIndexer { /// A handle to the background task managing the KV store. task: OnceLock>, /// The size of the KV block this indexer can handle. 
- kv_block_size: usize, + kv_block_size: u32, } impl KvIndexer { @@ -544,7 +544,7 @@ impl KvIndexer { pub fn new_with_frequency( token: CancellationToken, expiration_duration: Option<Duration>, - kv_block_size: usize, + kv_block_size: u32, ) -> Self { let (event_tx, event_rx) = mpsc::channel::(2048); let (match_tx, match_rx) = mpsc::channel::(128); @@ -611,11 +611,11 @@ impl KvIndexer { } } - pub fn block_size(&self) -> usize { + pub fn block_size(&self) -> u32 { self.kv_block_size } - pub fn new(token: CancellationToken, kv_block_size: usize) -> Self { + pub fn new(token: CancellationToken, kv_block_size: u32) -> Self { Self::new_with_frequency(token, None, kv_block_size) } @@ -697,7 +697,7 @@ pub struct KvIndexerSharded { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// The size of the KV block this indexer can handle. - kv_block_size: usize, + kv_block_size: u32, worker_assignments: HashMap, worker_counts: Vec, @@ -723,7 +723,7 @@ impl KvIndexerSharded { token: CancellationToken, num_shards: usize, expiration_duration: Option<Duration>, - kv_block_size: usize, + kv_block_size: u32, ) -> Self { let worker_assignments: HashMap = HashMap::new(); let worker_counts: Vec = vec![0; num_shards]; @@ -802,11 +802,11 @@ impl KvIndexerSharded { } } - pub fn block_size(&self) -> usize { + pub fn block_size(&self) -> u32 { self.kv_block_size } - pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self { + pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: u32) -> Self { Self::new_with_frequency(token, num_shards, None, kv_block_size) } } @@ -1312,24 +1312,20 @@ mod tests { #[case(11)] #[case(32)] #[case(64)] - fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) { + fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) { setup(); // create a sequence of kv_block_size elements - let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>(); + let sequence = (0..kv_block_size).collect::<Vec<u32>>(); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); assert_eq!(hashes.len(), 1); // create a sequence of kv_block_size + 1 elements - let sequence = (0..(kv_block_size + 1)) - .map(|i| i as u32) - .collect::<Vec<u32>>(); + let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>(); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); assert_eq!(hashes.len(), 1); // create a sequence of 2 * kv_block_size + 1 elements - let sequence = (0..(2 * kv_block_size + 1)) - .map(|i| i as u32) - .collect::<Vec<u32>>(); + let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>(); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); assert_eq!(hashes.len(), 2); } @@ -1337,7 +1333,7 @@ mod tests { fn make_indexer( token: &CancellationToken, num_shards: usize, - kv_block_size: usize, + kv_block_size: u32, ) -> Box { if num_shards == 1 { Box::new(KvIndexer::new(token.clone(), kv_block_size)) @@ -1360,7 +1356,7 @@ #[tokio::test] #[apply(indexer_template)] - async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) { + async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) { setup(); let token: CancellationToken = CancellationToken::new(); let _ = make_indexer(&token, num_shards, kv_block_size); } #[tokio::test] #[apply(indexer_template)] - async fn test_find_matches(num_shards: usize, kv_block_size: usize) { + async fn test_find_matches(num_shards: usize, kv_block_size: u32) { setup(); let token = CancellationToken::new(); let kv_indexer = make_indexer(&token, num_shards, kv_block_size); @@ -1381,7 +1377,7 @@
mod tests { #[tokio::test] #[apply(indexer_template)] - async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) { + async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) { setup(); let token = CancellationToken::new(); let kv_indexer = make_indexer(&token, num_shards, kv_block_size); @@ -1394,7 +1390,7 @@ mod tests { #[tokio::test] #[apply(indexer_template)] - async fn test_apply_event(num_shards: usize, kv_block_size: usize) { + async fn test_apply_event(num_shards: usize, kv_block_size: u32) { setup(); let worker_id = 0; @@ -1409,7 +1405,7 @@ mod tests { #[tokio::test] #[apply(indexer_template)] - async fn test_shutdown(num_shards: usize, kv_block_size: usize) { + async fn test_shutdown(num_shards: usize, kv_block_size: u32) { setup(); let token = CancellationToken::new(); let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size); @@ -1419,7 +1415,7 @@ mod tests { #[tokio::test] #[apply(indexer_template)] - async fn test_frequency(num_shards: usize, kv_block_size: usize) { + async fn test_frequency(num_shards: usize, kv_block_size: u32) { const ONE_MILLIS: Duration = Duration::from_millis(1); setup(); diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 7f02b0bd8d..c72cc4b4af 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -62,7 +62,7 @@ impl KvEventSource { /// Start the event source from a [`KvEventSourceConfig`]. fn start( component: Component, - kv_block_size: usize, + kv_block_size: u32, source_config: KvEventSourceConfig, cancellation_token: CancellationToken, tx: mpsc::UnboundedSender, @@ -98,7 +98,7 @@ impl KvEventSource { /// A publisher of KV events. pub struct KvEventPublisher { /// The size of the KV block. - kv_block_size: usize, + kv_block_size: u32, /// The source of KV events. /// Can be `None` if all events are provided through [`KvEventPublisher::publish`].
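The test matrix above now drives every indexer test over u32 block sizes. Outside the tests, constructing either indexer variant follows the same split the make_indexer helper uses; a sketch, assuming both constructors stay public as shown in this diff:

    use dynamo_llm::kv_router::indexer::{KvIndexer, KvIndexerSharded};
    use dynamo_runtime::runtime::CancellationToken;

    fn build_indexer(num_shards: usize, kv_block_size: u32) {
        let token = CancellationToken::new();
        if num_shards == 1 {
            // Single shard: the plain indexer suffices.
            let _indexer = KvIndexer::new(token, kv_block_size);
        } else {
            // Sharded: worker state is partitioned across num_shards background tasks.
            let _indexer = KvIndexerSharded::new(token, num_shards, kv_block_size);
        }
    }
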
source: Option, @@ -112,7 +112,7 @@ impl KvEventPublisher { pub fn new( component: Component, worker_id: i64, - kv_block_size: usize, + kv_block_size: u32, source_config: Option, ) -> Result { let cancellation_token = CancellationToken::new(); @@ -155,7 +155,7 @@ impl KvEventPublisher { self.tx.send(event) } - pub fn kv_block_size(&self) -> usize { + pub fn kv_block_size(&self) -> u32 { self.kv_block_size } @@ -223,7 +223,7 @@ pub async fn start_zmq_listener( zmq_topic: String, tx: mpsc::UnboundedSender, cancellation_token: CancellationToken, - kv_block_size: usize, + kv_block_size: u32, ) { tracing::debug!( "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')", @@ -335,7 +335,7 @@ pub async fn start_zmq_listener( fn convert_event( raw: RawKvEvent, event_id: u64, - kv_block_size: usize, + kv_block_size: u32, warning_count: &Arc, ) -> KvCacheEvent { match raw { @@ -382,7 +382,7 @@ fn convert_event( } pub fn create_stored_block_from_parts( - kv_block_size: usize, + kv_block_size: u32, block_hash: i64, token_ids: &[u32], _lora_id: u64, @@ -395,7 +395,7 @@ pub fn create_stored_block_from_parts( } pub fn create_stored_blocks( - kv_block_size: usize, + kv_block_size: u32, token_ids: &[u32], num_block_tokens: &[u64], block_hashes: &[i64], diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index f67dfa6127..0f893bcdba 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -92,7 +92,7 @@ pub struct KvScheduler { impl KvScheduler { pub async fn start( ns: Namespace, - block_size: usize, + block_size: u32, endpoints_rx: tokio::sync::watch::Receiver, selector: Option>, ) -> Result { @@ -299,7 +299,7 @@ impl WorkerSelector for DefaultWorkerSelector { &self, workers: &ProcessedEndpoints, request: &SchedulingRequest, - block_size: usize, + block_size: u32, ) -> Result { assert!(request.isl_tokens > 0); @@ -307,7 +307,7 @@ impl WorkerSelector for DefaultWorkerSelector { return Err(KvSchedulerError::NoEndpoints); } - let request_blocks = request.isl_tokens.div_ceil(block_size); + let request_blocks = request.isl_tokens.div_ceil(block_size as usize); let mut worker_logits = HashMap::new(); // Calculate logits for each worker diff --git a/lib/llm/src/lib.rs b/lib/llm/src/lib.rs index 871aa9543f..5027810f29 100644 --- a/lib/llm/src/lib.rs +++ b/lib/llm/src/lib.rs @@ -15,6 +15,7 @@ pub mod common; pub mod disagg_router; pub mod discovery; pub mod engines; +pub mod entrypoint; pub mod gguf; pub mod http; pub mod hub; diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 31e3d6ca71..f404a16524 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -5,6 +5,9 @@ use std::fs; use std::path::{Path, PathBuf}; use std::sync::Arc; +use anyhow::Context as _; +use dynamo_runtime::protocols::Endpoint as EndpointId; +use dynamo_runtime::slug::Slug; use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::{ component::{Component, Endpoint}, @@ -12,8 +15,10 @@ use dynamo_runtime::{ }; use crate::discovery::ModelEntry; +use crate::entrypoint::RouterConfig; use crate::model_card::{self, ModelDeploymentCard}; use crate::model_type::ModelType; +use crate::request_template::RequestTemplate; mod network_name; pub use network_name::ModelNetworkName; @@ -25,58 +30,85 @@ const HF_SCHEME: &str = "hf://"; /// is invisible, for example in a text chat. 
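One subtlety from the scheduler hunk above: request.isl_tokens stays usize while block_size is now u32, so the rounding-up division casts at the call site. Restated as a standalone sketch:

    // Mirrors the request_blocks computation in DefaultWorkerSelector:
    // tokens are counted in usize, the block size travels as u32.
    fn request_blocks(isl_tokens: usize, block_size: u32) -> usize {
        isl_tokens.div_ceil(block_size as usize)
    }

    fn main() {
        assert_eq!(request_blocks(1000, 64), 16); // 15.625 blocks round up to 16
        assert_eq!(request_blocks(1024, 64), 16); // exact multiple, no rounding
    }
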
const DEFAULT_NAME: &str = "dynamo"; -#[derive(Debug, Clone)] -pub struct LocalModel { - full_path: PathBuf, - card: ModelDeploymentCard, +/// Engines don't usually provide a default, so we do. +const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16; + +/// We can't have it default to 0, so pick something +const DEFAULT_HTTP_PORT: u16 = 8080; + +pub struct LocalModelBuilder { + model_path: Option, + model_name: Option, + model_config: Option, + endpoint_id: Option, + context_length: Option, + template_file: Option, + router_config: Option, + kv_cache_block_size: u32, + http_port: u16, } -impl Default for LocalModel { +impl Default for LocalModelBuilder { fn default() -> Self { - LocalModel { - full_path: PathBuf::new(), - card: ModelDeploymentCard::with_name_only(DEFAULT_NAME), + LocalModelBuilder { + kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE, + http_port: DEFAULT_HTTP_PORT, + model_path: Default::default(), + model_name: Default::default(), + model_config: Default::default(), + endpoint_id: Default::default(), + context_length: Default::default(), + template_file: Default::default(), + router_config: Default::default(), } } } -impl LocalModel { - pub fn with_name_only(name: &str) -> Self { - LocalModel { - card: ModelDeploymentCard::with_name_only(name), - ..Default::default() - } +impl LocalModelBuilder { + pub fn model_path(&mut self, model_path: Option) -> &mut Self { + self.model_path = model_path; + self } - pub fn card(&self) -> &ModelDeploymentCard { - &self.card + pub fn model_name(&mut self, model_name: Option) -> &mut Self { + self.model_name = model_name; + self } - pub fn path(&self) -> &Path { - &self.full_path + pub fn model_config(&mut self, model_config: Option) -> &mut Self { + self.model_config = model_config; + self } - pub fn display_name(&self) -> &str { - &self.card.display_name + pub fn endpoint_id(&mut self, endpoint_id: EndpointId) -> &mut Self { + self.endpoint_id = Some(endpoint_id); + self } - pub fn service_name(&self) -> &str { - &self.card.service_name + pub fn context_length(&mut self, context_length: Option) -> &mut Self { + self.context_length = context_length; + self } - pub fn is_gguf(&self) -> bool { - // GGUF is the only file (not-folder) we accept, so we don't need to check the extension - // We will error when we come to parse it - self.full_path.is_file() + /// Passing None resets it to default + pub fn kv_cache_block_size(&mut self, kv_cache_block_size: Option) -> &mut Self { + self.kv_cache_block_size = kv_cache_block_size.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE); + self } - /// Override max number of tokens in context. We usually only do this to limit kv cache allocation. 
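Taken together, these setters (with http_port and router_config following just below) replace the old LocalModel::prepare arguments. A sketch of the new flow; the model path is illustrative, and a RouterConfig is assumed to be built elsewhere, since build() expects one to have been set:

    use dynamo_llm::entrypoint::RouterConfig;
    use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};

    async fn load(router_config: RouterConfig) -> anyhow::Result<LocalModel> {
        let mut builder = LocalModelBuilder::default();
        builder
            .model_path(Some("Qwen/Qwen3-0.6B".into()))
            .context_length(Some(8192))
            .kv_cache_block_size(None) // None falls back to DEFAULT_KV_CACHE_BLOCK_SIZE
            .http_port(8080)
            .router_config(router_config);
        builder.build().await
    }
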
- pub fn set_context_length(&mut self, context_length: usize) { - self.card.context_length = context_length; + pub fn http_port(&mut self, port: u16) -> &mut Self { + self.http_port = port; + self } - pub fn set_kv_cache_block_size(&mut self, block_size: usize) { - self.card.kv_cache_block_size = block_size; + pub fn router_config(&mut self, router_config: RouterConfig) -> &mut Self { + self.router_config = Some(router_config); + self + } + + pub fn request_template(&mut self, template_file: Option) -> &mut Self { + self.template_file = template_file; + self } /// Make an LLM ready for use: @@ -88,28 +120,60 @@ impl LocalModel { /// The model name will depend on what "model_path" is: /// - A folder: The last part of the folder name: "/data/llms/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct" /// - A file: The GGUF filename: "/data/llms/Qwen2.5-3B-Instruct-Q6_K.gguf" -> "Qwen2.5-3B-Instruct-Q6_K.gguf" - /// - An HF repo: The HF repo name: "Qwen/Qwen2.5-3B-Instruct" stays the same - pub async fn prepare( - model_path: &str, - override_config: Option<&Path>, - override_name: Option, - ) -> anyhow::Result { - // Name it + /// - An HF repo: The HF repo name: "Qwen/Qwen3-0.6B" stays the same + pub async fn build(&mut self) -> anyhow::Result { + // Generate an endpoint ID for this model if the user didn't provide one. + // The user only provides one if exposing the model. + let endpoint_id = self + .endpoint_id + .take() + .unwrap_or_else(|| internal_endpoint("local_model")); + let template = self + .template_file + .as_deref() + .map(RequestTemplate::load) + .transpose()?; + + // echo_full engine doesn't need a path. It's an edge case, move it out of the way. + if self.model_path.is_none() { + return Ok(LocalModel { + card: ModelDeploymentCard::with_name_only( + self.model_name.as_deref().unwrap_or(DEFAULT_NAME), + ), + full_path: PathBuf::new(), + endpoint_id, + template, + http_port: self.http_port, + // We always have one. The Option is so we can take it. + router_config: self + .router_config + .take() + .expect("unreachable, RouterConfig missing"), + }); + } + + // Main logic. We are running a model. + let model_path = self.model_path.take().unwrap(); + let model_path = model_path.to_str().context("Invalid UTF-8 in model path")?; // Check for hf:// prefix first, in case we really want an HF repo but it conflicts // with a relative path. let is_hf_repo = model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false); let relative_path = model_path.trim_start_matches(HF_SCHEME); - let full_path = if is_hf_repo { // HF download if necessary super::hub::from_hf(relative_path).await? } else { fs::canonicalize(relative_path)? 
}; + // --model-config takes precedence over --model-path + let model_config_path = self.model_config.as_ref().unwrap_or(&full_path); - let model_name = override_name.unwrap_or_else(|| { + let mut card = ModelDeploymentCard::load(&model_config_path).await?; + + // Usually we infer the name from the path; self.model_name is a user override + let model_name = self.model_name.take().unwrap_or_else(|| { if is_hf_repo { // HF repos use their full name ("org/name") not the folder name relative_path.to_string() @@ -124,15 +188,83 @@ impl LocalModel { }) } }); + card.set_name(&model_name); - // Load the ModelDeploymentCard + card.kv_cache_block_size = self.kv_cache_block_size; - // --model-config takes precedence over --model-path - let model_config_path = override_config.unwrap_or(&full_path); - let mut card = ModelDeploymentCard::load(&model_config_path).await?; - card.set_name(&model_name); + // Override max number of tokens in context. We usually only do this to limit kv cache allocation. + if let Some(context_length) = self.context_length { + card.context_length = context_length; + } - Ok(LocalModel { full_path, card }) + Ok(LocalModel { + card, + full_path, + endpoint_id, + template, + http_port: self.http_port, + router_config: self + .router_config + .take() + .expect("unreachable, RouterConfig missing"), + }) + } +} + +#[derive(Debug, Clone)] +pub struct LocalModel { + full_path: PathBuf, + card: ModelDeploymentCard, + endpoint_id: EndpointId, + template: Option<RequestTemplate>, + http_port: u16, // Only used if input is HTTP server + router_config: RouterConfig, +} + +impl LocalModel { + pub fn card(&self) -> &ModelDeploymentCard { + &self.card + } + + pub fn path(&self) -> &Path { + &self.full_path + } + + pub fn display_name(&self) -> &str { + &self.card.display_name + } + + pub fn service_name(&self) -> &str { + &self.card.service_name + } + + pub fn request_template(&self) -> Option<RequestTemplate> { + self.template.clone() + } + + pub fn http_port(&self) -> u16 { + self.http_port + } + + pub fn router_config(&self) -> &RouterConfig { + &self.router_config + } + + pub fn is_gguf(&self) -> bool { + // GGUF is the only file (not-folder) we accept, so we don't need to check the extension + // We will error when we come to parse it + self.full_path.is_file() + } + + /// An endpoint to identify this model by. + pub fn endpoint_id(&self) -> &EndpointId { + &self.endpoint_id + } + + /// Drop the LocalModel, returning its ModelDeploymentCard. + /// For the case where we only need the card and don't want to clone it. + pub fn into_card(self) -> ModelDeploymentCard { + self.card } /// Attach this model to the endpoint.
This registers it on the network @@ -202,3 +334,13 @@ impl LocalModel { Ok(()) } } + +/// A random endpoint to use for internal communication +/// We can't hard code because we may be running several on the same machine (GPUs 0-3 and 4-7) +fn internal_endpoint(engine: &str) -> EndpointId { + EndpointId { + namespace: Slug::slugify(&uuid::Uuid::new_v4().to_string()).to_string(), + component: engine.to_string(), + name: "generate".to_string(), + } +} diff --git a/lib/llm/src/mocker/kv_manager.rs b/lib/llm/src/mocker/kv_manager.rs index 8a7a8fefed..d1cd4a41ec 100644 --- a/lib/llm/src/mocker/kv_manager.rs +++ b/lib/llm/src/mocker/kv_manager.rs @@ -57,7 +57,7 @@ pub struct KvManager { max_capacity: usize, #[getter(copy)] - block_size: usize, + block_size: u32, active_blocks: HashMap, @@ -67,7 +67,7 @@ pub struct KvManager { } impl KvManager { - pub fn new(max_capacity: usize, block_size: usize) -> Self { + pub fn new(max_capacity: usize, block_size: u32) -> Self { let active_blocks = HashMap::new(); let inactive_blocks = LRUEvictor::default(); let all_blocks = HashSet::new(); @@ -245,7 +245,7 @@ impl KvManager { let overlap_blocks = unique_blocks.len() - new_blocks; // Calculate new tokens - let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size; + let new_tokens = sequence.num_input_tokens() - overlap_blocks * (self.block_size as usize); // // Print the full equation with actual values substituted // println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)", @@ -261,7 +261,7 @@ impl KvManager { // Calculate prefill compute let prefill_compute = - new_tokens as f64 * (new_tokens + overlap_blocks * self.block_size) as f64; + new_tokens as f64 * (new_tokens + overlap_blocks * (self.block_size as usize)) as f64; Some(PrefillCost { new_tokens, diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs index 2abd19f24f..604a6e1589 100644 --- a/lib/llm/src/mocker/scheduler.rs +++ b/lib/llm/src/mocker/scheduler.rs @@ -193,7 +193,7 @@ impl Scheduler { pub fn new( kv_capacity: usize, watermark: f64, - block_size: usize, + block_size: u32, chunk_size: Option, output_tx: Option>, cancellation_token: Option, @@ -272,7 +272,7 @@ impl Scheduler { let mut kv_manager_guard = kv_manager_clone.lock().await; // Base time needed for decoding (assumed memory bound on KV cache) - let active_tokens = kv_manager_guard.num_active_blocks() * block_size; + let active_tokens = kv_manager_guard.num_active_blocks() * (block_size as usize); // TODO: 2 is a dummy / magic scaling factor let mut generation_time = Duration::from_micros((active_tokens / 2) as u64); @@ -406,7 +406,7 @@ impl Scheduler { } /// Convert a Request to an ActiveSequence -fn get_active_sequence(request: Request, block_size: usize, chunk_size: usize) -> ActiveSequence { +fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) -> ActiveSequence { if let Request::Active(active_seq) = request { return active_seq; } @@ -475,7 +475,7 @@ mod tests { let kv_capacity: usize = 500; let watermark: f64 = 0.01; // 1% watermark - let block_size: usize = 64; + let block_size: u32 = 64; let chunk_size: usize = 256; let num_requests: usize = 100; let input_len: usize = 1000; diff --git a/lib/llm/src/mocker/sequence.rs b/lib/llm/src/mocker/sequence.rs index d53dd870e1..e8900fae2c 100644 --- a/lib/llm/src/mocker/sequence.rs +++ b/lib/llm/src/mocker/sequence.rs @@ -23,7 +23,7 @@ use uuid; fn create_unique_blocks_from_sequence( tokens: &TokenBlockSequence, uuid: Option, - 
block_size: usize, + block_size: u32, ) -> Vec { let mut unique_blocks: Vec = tokens .blocks() @@ -32,7 +32,7 @@ fn create_unique_blocks_from_sequence( .collect(); // Only push the partial block if tokens count isn't a multiple of block_size - if tokens.total_tokens() % block_size != 0 { + if tokens.total_tokens() % (block_size as usize) != 0 { unique_blocks.push(match uuid { Some(uuid) => UniqueBlock::PartialBlock(uuid), None => UniqueBlock::default(), @@ -50,7 +50,7 @@ pub struct ActiveSequence { tokens: TokenBlockSequence, #[getter(copy)] - block_size: usize, + block_size: u32, #[getter(copy)] chunk_size: usize, // TODO: not actually used @@ -72,7 +72,7 @@ impl ActiveSequence { pub fn new( tokens: Vec, max_output_tokens: usize, - block_size: Option, + block_size: Option, chunk_size: Option, ) -> Self { let block_size = block_size.unwrap_or(64); @@ -96,8 +96,8 @@ impl ActiveSequence { } } - pub fn extra_tokens(&self) -> usize { - self.len() % self.block_size + pub fn extra_tokens(&self) -> u32 { + (self.len() % self.block_size as usize) as u32 } pub fn len(&self) -> usize { @@ -112,7 +112,7 @@ impl ActiveSequence { pub fn new_with_signal( tokens: Vec, max_output_tokens: usize, - block_size: Option, + block_size: Option, chunk_size: Option, ) -> (Self, Option) { let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size); @@ -125,7 +125,7 @@ impl ActiveSequence { self.tokens.append(token).expect("Token push failed."); self.generated_tokens += 1; - if self.len() % self.block_size != 1 { + if self.len() % (self.block_size as usize) != 1 { return None; } @@ -223,7 +223,7 @@ impl ActiveSequence { self.generated_tokens = self.generated_tokens.saturating_sub(1); // Reverts to the last full block - if self.tokens.total_tokens() % self.block_size == 0 { + if self.tokens.total_tokens() % (self.block_size as usize) == 0 { self.unique_blocks.pop(); } } @@ -285,7 +285,7 @@ mod tests { // Verify state after pushing tokens assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block assert_eq!(seq1.len(), 17); - assert_eq!(seq1.len() % seq1.block_size(), 1); + assert_eq!(seq1.len() % (seq1.block_size() as usize), 1); // Create another sequence with block size 16 initialized with tokens [0..17] let extended_tokens: Vec = (0..16).collect(); @@ -335,12 +335,12 @@ mod tests { "seq2 should have exactly 3 blocks" ); assert_eq!( - seq1.len() % seq1.block_size(), + seq1.len() % (seq1.block_size() as usize), 1, "seq1 should have 1 partial token" ); assert_eq!( - seq2.len() % seq2.block_size(), + seq2.len() % (seq2.block_size() as usize), 1, "seq2 should have 1 partial token" ); diff --git a/lib/llm/src/model_card/create.rs b/lib/llm/src/model_card/create.rs index c4a84d94a2..99df2b2c5c 100644 --- a/lib/llm/src/model_card/create.rs +++ b/lib/llm/src/model_card/create.rs @@ -76,7 +76,7 @@ impl ModelDeploymentCard { let content = super::model::load_gguf(gguf_file)?; let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())] .to_u32() - .unwrap_or(0) as usize; + .unwrap_or(0); tracing::debug!(context_length, "Loaded context length from GGUF"); Ok(Self { diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs index a3f9a7c7ec..031d2baf41 100644 --- a/lib/llm/src/model_card/model.rs +++ b/lib/llm/src/model_card/model.rs @@ -117,11 +117,11 @@ pub struct ModelDeploymentCard { pub revision: u64, /// Max context (in number of tokens) this model can handle - pub context_length: usize, + pub context_length: u32, /// Size of a 
KV cache block - vllm only currently /// Passed to the engine and the KV router. - pub kv_cache_block_size: usize, + pub kv_cache_block_size: u32, } impl ModelDeploymentCard { diff --git a/lib/llm/src/tokenizers/sp.rs b/lib/llm/src/tokenizers/sp.rs index 2098f881f7..6eaa4b3647 100644 --- a/lib/llm/src/tokenizers/sp.rs +++ b/lib/llm/src/tokenizers/sp.rs @@ -81,7 +81,7 @@ impl Decoder for SentencePieceTokenizer { /// # Arguments /// * `token_ids` - The sequence of token IDs to decode /// * `skip_special_tokens` - Currently unsupported in SentencePieceTokenizer and - /// it will return an error if true + /// it will return an error if true /// /// # Returns /// * `Result` - The decoded text diff --git a/lib/llm/src/tokens.rs b/lib/llm/src/tokens.rs index 645919d1bd..db55530185 100644 --- a/lib/llm/src/tokens.rs +++ b/lib/llm/src/tokens.rs @@ -155,11 +155,7 @@ impl Tokens { /// /// * `block_size` - The fixed size for each [`TokenBlock`]. /// * `salt_hash` - An optional [`SaltHash`] used as the base seed for hashing. Defaults to 0. - pub fn into_sequence( - self, - block_size: usize, - salt_hash: Option, - ) -> TokenBlockSequence { + pub fn into_sequence(self, block_size: u32, salt_hash: Option) -> TokenBlockSequence { TokenBlockSequence::new(self, block_size, salt_hash) } } @@ -191,7 +187,7 @@ pub enum TokenBlockError { #[derive(Debug, PartialEq)] // No Clone: intended to be unique within a sequence pub struct PartialTokenBlock { tokens: Tokens, - block_size: usize, + block_size: u32, salt_hash: SaltHash, parent_sequence_hash: Option, } @@ -203,7 +199,7 @@ impl PartialTokenBlock { /// /// * `block_size` - The fixed size for blocks in this sequence. /// * `salt_hash` - The [`SaltHash`] for the sequence. - pub(crate) fn create_sequence_root(block_size: usize, salt_hash: SaltHash) -> Self { + pub(crate) fn create_sequence_root(block_size: u32, salt_hash: SaltHash) -> Self { Self { tokens: Tokens::default(), block_size, @@ -223,7 +219,7 @@ impl PartialTokenBlock { /// * `Ok(())` - If the token was successfully added. /// * `Err(TokenBlockError::Full)` - If the block already contains `block_size` tokens. pub(crate) fn push_token(&mut self, token: Token) -> Result<(), TokenBlockError> { - if self.tokens.0.len() >= self.block_size { + if self.tokens.0.len() >= self.block_size as usize { return Err(TokenBlockError::Full); } self.tokens.0.push(token); @@ -305,7 +301,7 @@ impl PartialTokenBlock { /// * `Ok(TokenBlock)` - The newly created full [`TokenBlock`]. /// * `Err(TokenBlockError::Incomplete)` - If the block does not contain exactly `block_size` tokens. pub(crate) fn commit(&mut self) -> Result { - if self.tokens.0.len() != self.block_size { + if self.tokens.0.len() != self.block_size as usize { // Check for exact size match for committing return Err(TokenBlockError::Incomplete); } @@ -327,7 +323,7 @@ impl PartialTokenBlock { /// Returns the number of additional tokens required to fill the block. pub fn remaining(&self) -> usize { // Use saturating_sub to prevent underflow if len somehow exceeds block_size - self.block_size.saturating_sub(self.tokens.0.len()) + (self.block_size as usize).saturating_sub(self.tokens.0.len()) } /// Returns the number of tokens currently in the block. 
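With block sizes now u32 end to end, Tokens::into_sequence keeps its shape and only the parameter type changes. A sketch of the resulting block/partial split (assuming the usual From conversions into Tokens):

    use dynamo_llm::tokens::Tokens;

    fn main() {
        let tokens = Tokens::from(vec![0u32, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        // Ten tokens at block_size 4: two full blocks plus a 2-token partial.
        let seq = tokens.into_sequence(4, None);
        assert_eq!(seq.blocks().len(), 2);
        assert_eq!(seq.total_tokens(), 10);
    }
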
@@ -408,7 +404,7 @@ impl TokenBlock { pub fn next_block(&self) -> PartialTokenBlock { PartialTokenBlock { tokens: Tokens::default(), - block_size: self.tokens.len(), // Should be == self.block_size + block_size: self.tokens.len() as u32, // Should be == self.block_size salt_hash: self.salt_hash, parent_sequence_hash: Some(self.sequence_hash), // Link to this block } } @@ -500,7 +496,7 @@ impl TokenBlockSequence { /// # Panics /// /// Panics if `block_size` is 0. - pub fn new(tokens: Tokens, block_size: usize, salt_hash: Option<SaltHash>) -> Self { + pub fn new(tokens: Tokens, block_size: u32, salt_hash: Option<SaltHash>) -> Self { assert!(block_size > 0, "block_size must be greater than 0"); let salt_hash = salt_hash.unwrap_or(0); let (blocks, current_block) = Self::split_tokens(&tokens, block_size, salt_hash); @@ -640,7 +636,7 @@ impl TokenBlockSequence { let tokens_to_pop_from_blocks = n - current_len; // Calculate how many blocks are affected (including the one partially popped) - let num_blocks_to_affect = tokens_to_pop_from_blocks.div_ceil(block_size); + let num_blocks_to_affect = tokens_to_pop_from_blocks.div_ceil(block_size as usize); // Check if we need to pop more blocks than available (should be prevented by initial len check) if num_blocks_to_affect > self.blocks.len() { @@ -657,10 +653,10 @@ impl TokenBlockSequence { // Calculate how many tokens to keep from that source block let num_full_blocks_completely_popped = num_blocks_to_affect - 1; - let num_tokens_to_pop_from_source_block = - tokens_to_pop_from_blocks - num_full_blocks_completely_popped * block_size; + let num_tokens_to_pop_from_source_block = tokens_to_pop_from_blocks + - num_full_blocks_completely_popped * block_size as usize; let num_tokens_to_keep_in_new_partial = - block_size.saturating_sub(num_tokens_to_pop_from_source_block); + (block_size as usize).saturating_sub(num_tokens_to_pop_from_source_block); // Get the tokens for the new partial block let new_partial_tokens = if num_tokens_to_keep_in_new_partial > 0 { @@ -789,7 +785,7 @@ impl TokenBlockSequence { /// Returns the total number of tokens in the sequence (sum of tokens in all completed blocks /// plus tokens in the current partial block). pub fn total_tokens(&self) -> usize { - let block_size = self.current_block.block_size; + let block_size = self.current_block.block_size as usize; (self.blocks.len() * block_size) + self.current_block.len() } @@ -812,14 +808,14 @@ impl TokenBlockSequence { /// Panics if `block_size` is 0. pub fn split_tokens( tokens: &[Token], - block_size: usize, + block_size: u32, salt_hash: u64, ) -> (Vec<TokenBlock>, PartialTokenBlock) { assert!(block_size > 0, "block_size must be greater than 0"); // Use Rayon for parallel computation of block chunks (hashes) let chunks: Vec<TokenBlockChunk> = tokens .as_ref() - .par_chunks_exact(block_size) + .par_chunks_exact(block_size as usize) .map(|chunk| TokenBlockChunk::from_tokens(chunk, salt_hash)) .collect(); @@ -834,7 +830,10 @@ impl TokenBlockSequence { } // Handle any remaining tokens - let remainder = tokens.as_ref().chunks_exact(block_size).remainder(); + let remainder = tokens + .as_ref() + .chunks_exact(block_size as usize) + .remainder(); let current_block = PartialTokenBlock { tokens: remainder.into(), @@ -856,7 +855,7 @@ mod tests { // Helper to create a sequence for testing fn create_test_sequence( initial_tokens: &[Token], - block_size: usize, + block_size: u32, salt_hash: Option<SaltHash>, ) -> TokenBlockSequence { TokenBlockSequence::new(Tokens::from(initial_tokens), block_size, salt_hash)
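To close the loop on split_tokens: hashing parallelizes only over the exact block_size chunks, and the remainder seeds the open partial block. A sketch of that invariant, assuming Token is the u32 token id used throughout and that PartialTokenBlock::len is the accessor documented above:

    use dynamo_llm::tokens::TokenBlockSequence;

    fn main() {
        // Five tokens at block_size 2: two full blocks, one token left over.
        let (blocks, partial) = TokenBlockSequence::split_tokens(&[1u32, 2, 3, 4, 5], 2, 0);
        assert_eq!(blocks.len(), 2);  // [1,2] and [3,4] were hashed in parallel
        assert_eq!(partial.len(), 1); // the trailing 5 starts the next partial block
    }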