diff --git a/Cargo.lock b/Cargo.lock index 1f60c88f60..5bafcee522 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -747,6 +747,8 @@ dependencies = [ "cc", "cfg-if 1.0.3", "constant_time_eq", + "memmap2", + "rayon-core", ] [[package]] diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index d99526e08d..690a69019d 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -611,6 +611,8 @@ dependencies = [ "cc", "cfg-if 1.0.3", "constant_time_eq", + "memmap2", + "rayon-core", ] [[package]] diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 3174b9523d..4b7abb040e 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -94,7 +94,7 @@ xxhash-rust = { workspace = true } akin = "0.4.0" bitflags = { version = "2.4", features = ["serde"] } -blake3 = "1" +blake3 = { version = "1.8", features=["mmap", "rayon"] } bytemuck = "1.22" candle-core = { version = "0.9.1" } derive-getters = "0.5" diff --git a/lib/llm/src/common.rs b/lib/llm/src/common.rs index 2d97064bd1..9d115fcb7b 100644 --- a/lib/llm/src/common.rs +++ b/lib/llm/src/common.rs @@ -1,17 +1,6 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +pub mod checked_file; pub mod dtype; pub mod versioned; diff --git a/lib/llm/src/common/checked_file.rs b/lib/llm/src/common/checked_file.rs new file mode 100644 index 0000000000..8cacd48b84 --- /dev/null +++ b/lib/llm/src/common/checked_file.rs @@ -0,0 +1,330 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::{ + fmt::Display, + path::{Path, PathBuf}, + str::FromStr, +}; + +use either::Either; +use serde::{ + Deserialize, Deserializer, Serialize, Serializer, + de::{self, Visitor}, + ser::SerializeStruct as _, +}; +use url::Url; + +#[derive(Clone, Debug)] +pub struct CheckedFile { + /// Either a path on local disk or a remote URL (usually nats object store) + path: Either, + + /// Checksum of the contents of path + checksum: Checksum, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Checksum { + /// The checksum is a hex encoded string of the file's content + hash: String, + + /// Checksum algorithm + algorithm: CryptographicHashMethods, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] +pub enum CryptographicHashMethods { + #[serde(rename = "blake3")] + BLAKE3, +} + +impl CheckedFile { + pub fn from_disk>(filepath: P) -> anyhow::Result { + let path: PathBuf = filepath.into(); + if !path.exists() { + anyhow::bail!("File not found: {}", path.display()); + } + if !path.is_file() { + anyhow::bail!("Not a file: {}", path.display()); + } + let hash = b3sum(&path)?; + + Ok(CheckedFile { + path: Either::Left(path), + checksum: Checksum::blake3(hash), + }) + } + + /// Replace the local disk path with a remote URL. + /// Just updates the field, doesn't move any files. + pub fn move_to_url(&mut self, u: url::Url) { + self.path = Either::Right(u); + } + + /// Replace a remove URL with local disk path. + /// Just updates the field, doesn't move any files. + pub fn move_to_disk>(&mut self, p: P) { + self.path = Either::Left(p.into()); + } + + pub fn path(&self) -> Option<&Path> { + match self.path.as_ref() { + Either::Left(p) => Some(p), + Either::Right(_) => None, + } + } + + pub fn url(&self) -> Option<&Url> { + match self.path.as_ref() { + Either::Left(_) => None, + Either::Right(u) => Some(u), + } + } + + pub fn is_nats_url(&self) -> bool { + matches!(self.path.as_ref(), Either::Right(u) if u.scheme() == "nats") + } + + pub fn checksum(&self) -> &Checksum { + &self.checksum + } + + /// Does the given file checksum to the same value as this CheckedFile? + pub fn checksum_matches + std::fmt::Debug>(&self, disk_file: P) -> bool { + match b3sum(&disk_file) { + Ok(h) => Checksum::blake3(h) == self.checksum, + Err(error) => { + tracing::error!(disk_file = %disk_file.as_ref().display(), checked_file = self.to_string(), %error, "Checksum does not match"); + false + } + } + } +} + +impl Display for CheckedFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let p = match &self.path { + Either::Left(local) => local.display().to_string(), + Either::Right(url) => url.to_string(), + }; + write!(f, "({p}, {})", self.checksum) + } +} + +impl Serialize for CheckedFile { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut cf = serializer.serialize_struct("CheckedFile", 2)?; + match &self.path { + Either::Left(path) => cf.serialize_field("path", &path)?, + Either::Right(url) => cf.serialize_field("path", &url)?, + }; + cf.serialize_field("checksum", &self.checksum)?; + cf.end() + } +} + +/// Internal type to simplify deserializing +#[derive(Deserialize)] +struct WireCheckedFile { + path: String, + checksum: Checksum, +} + +// Convert from the temporary struct to CheckedFile with path type logic. +impl From for CheckedFile { + fn from(temp: WireCheckedFile) -> Self { + // Try to parse as a URL; if successful, use Either::Right(Url), else use Either::Left(PathBuf). + match Url::parse(&temp.path) { + Ok(url) => CheckedFile { + path: Either::Right(url), + checksum: temp.checksum, + }, + Err(_) => CheckedFile { + path: Either::Left(PathBuf::from(temp.path)), + checksum: temp.checksum, + }, + } + } +} + +// Implement Deserialize for CheckedFile using the temporary struct. +impl<'de> Deserialize<'de> for CheckedFile { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Deserialize into WireCheckedFile, then convert to CheckedFile. + let temp = WireCheckedFile::deserialize(deserializer)?; + Ok(CheckedFile::from(temp)) + } +} + +fn b3sum + std::fmt::Debug>(path: T) -> anyhow::Result { + let path = path.as_ref(); + let metadata = std::fs::metadata(path)?; + let filesize = metadata.len(); + let mut hasher = blake3::Hasher::new(); + + if filesize > 128_000 { + // multithreaded. blake3 recommend this above 128 KiB. + hasher.update_mmap_rayon(path)?; + } else { + // Uses mmap above 16 KiB, normal load otherwise. + hasher.update_mmap(path)?; + } + + let hash = hasher.finalize(); + Ok(hash.to_string()) +} + +impl Checksum { + pub fn blake3(hash: impl Into) -> Self { + Self::new(hash, CryptographicHashMethods::BLAKE3) + } + + pub fn new(hash: impl Into, algorithm: CryptographicHashMethods) -> Self { + Self { + hash: hash.into(), + algorithm, + } + } +} + +impl Serialize for Checksum { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let serialized_str = format!("{}:{}", self.algorithm, self.hash); + serializer.serialize_str(&serialized_str) + } +} + +impl<'de> Deserialize<'de> for Checksum { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ChecksumVisitor; + + impl Visitor<'_> for ChecksumVisitor { + type Value = Checksum; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string in the format `{algo}:{hash}`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + let parts: Vec<&str> = value.split(':').collect(); + if parts.len() != 2 { + return Err(de::Error::invalid_value(de::Unexpected::Str(value), &self)); + } + + let algorithm = parts[0].parse().map_err(|_| { + de::Error::invalid_value(de::Unexpected::Str(parts[0]), &"invalid algorithm") + })?; + + Ok(Checksum::new(parts[1], algorithm)) + } + } + + deserializer.deserialize_str(ChecksumVisitor) + } +} + +impl TryFrom<&str> for Checksum { + type Error = anyhow::Error; + + fn try_from(value: &str) -> Result { + let parts: Vec<&str> = value.split(':').collect(); + if parts.len() != 2 { + anyhow::bail!("Invalid checksum format; expect `algo:hash`; got: {value}"); + } + + let algo = match parts[0] { + "blake3" => CryptographicHashMethods::BLAKE3, + _ => { + anyhow::bail!("Unsupported cryptographic hash method: {}", parts[0]); + } + }; + + Ok(Checksum::new(parts[1], algo)) + } +} + +impl FromStr for CryptographicHashMethods { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "blake3" => Ok(CryptographicHashMethods::BLAKE3), + _ => Err(format!("Unsupported algorithm: {}", s)), + } + } +} + +impl Display for CryptographicHashMethods { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CryptographicHashMethods::BLAKE3 => write!(f, "blake3"), + } + } +} + +impl Display for Checksum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{}", self.algorithm, self.hash) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialization_blake3() { + let checksum = Checksum::blake3("a12c3d4"); + + let serialized = serde_json::to_string(&checksum).unwrap(); + assert_eq!(serialized.trim(), "\"blake3:a12c3d4\""); + } + + #[test] + fn test_deserialization_blake3() { + let s = "\"blake3:abcd1234\""; + let deserialized: Checksum = serde_json::from_str(s).unwrap(); + + assert_eq!(deserialized.algorithm, CryptographicHashMethods::BLAKE3); + assert_eq!(deserialized.hash, "abcd1234"); + } + + #[test] + fn test_deserialization_invalid_format() { + let s = "\"invalidformat\""; + let result: Result = serde_json::from_str(s); + + assert!(result.is_err()); + + let s = "\"blake3:invalid:format\""; + let result: Result = serde_json::from_str(s); + + assert!(result.is_err()); + } + + #[test] + fn test_checked_file_from_disk() { + let root = env!("CARGO_MANIFEST_DIR"); // ${WORKSPACE}/lib/llm + let full_path = format!("{root}/tests/data/sample-models/TinyLlama_v1.1/config.json"); + let cf = CheckedFile::from_disk(full_path).unwrap(); + let expected = + Checksum::blake3("62bc124be974d3a25db05bedc99422660c26715e5bbda0b37d14bd84a0c65ab2"); + assert_eq!(expected, *cf.checksum()); + } +} diff --git a/lib/llm/src/kv_router/subscriber.rs b/lib/llm/src/kv_router/subscriber.rs index 97d65e5e45..b161db8d61 100644 --- a/lib/llm/src/kv_router/subscriber.rs +++ b/lib/llm/src/kv_router/subscriber.rs @@ -119,7 +119,7 @@ pub async fn start_kv_router_background( ))?; match nats_client - .object_store_download_data::>(url) + .object_store_download_data::>(&url) .await { Ok(events) => { @@ -353,7 +353,7 @@ async fn perform_snapshot_and_purge( resources .nats_client - .object_store_upload_data(&events, url) + .object_store_upload_data(&events, &url) .await .map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?; diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs index 93e0ad09e4..01fb116b90 100644 --- a/lib/llm/src/model_card.rs +++ b/lib/llm/src/model_card.rs @@ -19,13 +19,13 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; +use crate::common::checked_file::CheckedFile; use crate::local_model::runtime_config::ModelRuntimeConfig; use anyhow::{Context, Result}; use derive_builder::Builder; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; use serde::{Deserialize, Serialize}; use tokenizers::Tokenizer as HfTokenizer; -use url::Url; use crate::gguf::{Content, ContentConfig, ModelConfigLike}; use crate::protocols::TokenIdType; @@ -39,14 +39,14 @@ const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5); #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "snake_case")] pub enum ModelInfoType { - HfConfigJson(String), + HfConfigJson(CheckedFile), GGUF(PathBuf), } #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "snake_case")] pub enum TokenizerKind { - HfTokenizerJson(String), + HfTokenizerJson(CheckedFile), GGUF(Box), } @@ -65,8 +65,8 @@ pub enum TokenizerKind { #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "snake_case")] pub enum PromptFormatterArtifact { - HfTokenizerConfigJson(String), - HfChatTemplate(String), + HfTokenizerConfigJson(CheckedFile), + HfChatTemplate(CheckedFile), GGUF(PathBuf), } @@ -83,7 +83,7 @@ pub enum PromptContextMixin { #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "snake_case")] pub enum GenerationConfig { - HfGenerationConfigJson(String), + HfGenerationConfigJson(CheckedFile), GGUF(PathBuf), } @@ -223,8 +223,11 @@ impl ModelDeploymentCard { pub fn tokenizer_hf(&self) -> anyhow::Result { match &self.tokenizer { - Some(TokenizerKind::HfTokenizerJson(file)) => { - HfTokenizer::from_file(file).map_err(anyhow::Error::msg) + Some(TokenizerKind::HfTokenizerJson(checked_file)) => { + let p = checked_file.path().ok_or_else(|| + anyhow::anyhow!("Tokenizer is URL-backed ({:?}); call move_from_nats() before tokenizer_hf()", checked_file.url()) + )?; + HfTokenizer::from_file(p).map_err(anyhow::Error::msg) } Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()), None => { @@ -253,22 +256,23 @@ impl ModelDeploymentCard { macro_rules! nats_upload { ($field:expr, $enum_variant:path, $filename:literal) => { - if let Some($enum_variant(src_file)) = $field.take() { - if !nats::is_nats_url(&src_file) { - let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename); - nats_client - .object_store_upload( - &std::path::PathBuf::from(&src_file), - url::Url::parse(&target)?, - ) - .await?; - $field = Some($enum_variant(target)); - } + if let Some($enum_variant(src_file)) = $field.as_mut() + && let Some(path) = src_file.path() + { + let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename); + let dest = url::Url::parse(&target)?; + nats_client.object_store_upload(path, &dest).await?; + src_file.move_to_url(dest); } }; } nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json"); + nats_upload!( + self.gen_config, + GenerationConfig::HfGenerationConfigJson, + "generation_config.json" + ); nats_upload!( self.prompt_formatter, PromptFormatterArtifact::HfTokenizerConfigJson, @@ -284,11 +288,6 @@ impl ModelDeploymentCard { TokenizerKind::HfTokenizerJson, "tokenizer.json" ); - nats_upload!( - self.gen_config, - GenerationConfig::HfGenerationConfigJson, - "generation_config.json" - ); Ok(()) } @@ -310,19 +309,29 @@ impl ModelDeploymentCard { macro_rules! nats_download { ($field:expr, $enum_variant:path, $filename:literal) => { - if let Some($enum_variant(src_url)) = $field.take() { - if nats::is_nats_url(&src_url) { - let target = target_dir.path().join($filename); - nats_client - .object_store_download(Url::parse(&src_url)?, &target) - .await?; - $field = Some($enum_variant(target.display().to_string())); + if let Some($enum_variant(src_file)) = $field.as_mut() + && let Some(src_url) = src_file.url() + { + let target = target_dir.path().join($filename); + nats_client.object_store_download(src_url, &target).await?; + if !src_file.checksum_matches(&target) { + anyhow::bail!( + "Invalid {} in NATS for {}, checksum does not match.", + $filename, + self.display_name + ); } + src_file.move_to_disk(target); } }; } nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json"); + nats_download!( + self.gen_config, + GenerationConfig::HfGenerationConfigJson, + "generation_config.json" + ); nats_download!( self.prompt_formatter, PromptFormatterArtifact::HfTokenizerConfigJson, @@ -338,11 +347,6 @@ impl ModelDeploymentCard { TokenizerKind::HfTokenizerJson, "tokenizer.json" ); - nats_download!( - self.gen_config, - GenerationConfig::HfGenerationConfigJson, - "generation_config.json" - ); Ok(target_dir) } @@ -499,7 +503,7 @@ impl ModelDeploymentCard { })?; Some(PromptFormatterArtifact::HfChatTemplate( - template_path.display().to_string(), + CheckedFile::from_disk(template_path)?, )) } else { PromptFormatterArtifact::chat_template_from_repo(repo_id)? @@ -563,8 +567,13 @@ pub trait ModelInfo: Send + Sync { impl ModelInfoType { pub fn get_model_info(&self) -> Result> { match self { - Self::HfConfigJson(info) => HFConfig::from_json_file(info), - Self::GGUF(path) => HFConfig::from_gguf(path), + Self::HfConfigJson(checked_file) => { + let Some(path) = checked_file.path() else { + anyhow::bail!("model info is not a local path: {checked_file:?}"); + }; + Ok(HFConfig::from_json_file(path)?) + } + Self::GGUF(path) => Ok(HFConfig::from_gguf(path)?), } } pub fn is_gguf(&self) -> bool { @@ -615,9 +624,9 @@ struct HFTextConfig { } impl HFConfig { - fn from_json_file(file: &str) -> Result> { - let file_pathbuf = PathBuf::from(file); - let contents = std::fs::read_to_string(file)?; + fn from_json_file>(file: P) -> Result> { + let file_path = file.as_ref(); + let contents = std::fs::read_to_string(file_path)?; let mut config: Self = serde_json::from_str(&contents)?; if config.text_config.is_none() { let text_config: HFTextConfig = serde_json::from_str(&contents)?; @@ -630,17 +639,15 @@ impl HFConfig { ); }; + let gencfg_path = file_path + .parent() + .unwrap_or_else(|| Path::new("")) + .join("generation_config.json"); if text_config.bos_token_id.is_none() { - let bos_token_id = crate::file_json_field::( - &Path::join( - file_pathbuf.parent().unwrap_or(&PathBuf::from("")), - "generation_config.json", - ), - "bos_token_id", - ) - .context( - "missing bos_token_id in generation_config.json and config.json, cannot load", - )?; + let bos_token_id = crate::file_json_field::(&gencfg_path, "bos_token_id") + .context( + "missing bos_token_id in generation_config.json and config.json, cannot load", + )?; text_config.bos_token_id = Some(bos_token_id); } // Now that we have it for sure, set it in the non-Option field @@ -672,7 +679,7 @@ impl HFConfig { } else { tracing::error!( ?v, - file, + path = %file_path.display(), "eos_token_id is not a number or an array, cannot use" ); None @@ -680,13 +687,7 @@ impl HFConfig { }) .or_else(|| { // Maybe it's in generation_config.json - crate::file_json_field( - &Path::join( - file_pathbuf.parent().unwrap_or(&PathBuf::from("")), - "generation_config.json", - ), - "eos_token_id", - ) + crate::file_json_field(&gencfg_path, "eos_token_id") .inspect_err( |err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"), ) @@ -794,12 +795,17 @@ fn capitalize(s: &str) -> String { impl ModelInfoType { pub fn from_repo(repo_id: &str) -> Result { - Self::try_is_hf_repo(repo_id) - .with_context(|| format!("unable to extract model info from repo {}", repo_id)) + let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("config.json")) + .with_context(|| format!("unable to extract config.json from repo {repo_id}"))?; + Ok(Self::HfConfigJson(f)) } +} - fn try_is_hf_repo(repo: &str) -> anyhow::Result { - Ok(Self::HfConfigJson(check_for_file(repo, "config.json")?)) +impl GenerationConfig { + pub fn from_repo(repo_id: &str) -> Result { + let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("generation_config.json")) + .with_context(|| format!("unable to extract generation_config from repo {repo_id}"))?; + Ok(Self::HfGenerationConfigJson(f)) } } @@ -807,68 +813,26 @@ impl PromptFormatterArtifact { pub fn from_repo(repo_id: &str) -> Result> { // we should only error if we expect a prompt formatter and it's not found // right now, we don't know when to expect it, so we just return Ok(Some/None) - Ok(Self::try_is_hf_repo(repo_id) - .with_context(|| format!("unable to extract prompt format from repo {}", repo_id)) - .ok()) + match CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer_config.json")) { + Ok(f) => Ok(Some(Self::HfTokenizerConfigJson(f))), + Err(_) => Ok(None), + } } pub fn chat_template_from_repo(repo_id: &str) -> Result> { - Ok(Self::chat_template_try_is_hf_repo(repo_id) - .with_context(|| format!("unable to extract prompt format from repo {}", repo_id)) - .ok()) - } - - fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result { - Ok(Self::HfChatTemplate(check_for_file( - repo, - "chat_template.jinja", - )?)) - } - - fn try_is_hf_repo(repo: &str) -> anyhow::Result { - Ok(Self::HfTokenizerConfigJson(check_for_file( - repo, - "tokenizer_config.json", - )?)) + match CheckedFile::from_disk(PathBuf::from(repo_id).join("chat_template.jinja")) { + Ok(f) => Ok(Some(Self::HfChatTemplate(f))), + Err(_) => Ok(None), + } } } impl TokenizerKind { pub fn from_repo(repo_id: &str) -> Result { - Self::try_is_hf_repo(repo_id) - .with_context(|| format!("unable to extract tokenizer kind from repo {}", repo_id)) - } - - fn try_is_hf_repo(repo: &str) -> anyhow::Result { - Ok(Self::HfTokenizerJson(check_for_file( - repo, - "tokenizer.json", - )?)) - } -} - -impl GenerationConfig { - pub fn from_repo(repo_id: &str) -> Result { - Self::try_is_hf_repo(repo_id) - .with_context(|| format!("unable to extract generation config from repo {repo_id}")) - } - - fn try_is_hf_repo(repo: &str) -> anyhow::Result { - Ok(Self::HfGenerationConfigJson(check_for_file( - repo, - "generation_config.json", - )?)) - } -} - -/// Checks if the provided path contains the expected file. -fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result { - let p = PathBuf::from(repo_id).join(file); - let name = p.display().to_string(); - if !p.exists() { - anyhow::bail!("File not found: {name}") + let f = CheckedFile::from_disk(PathBuf::from(repo_id).join("tokenizer.json")) + .with_context(|| format!("unable to extract tokenizer kind from repo {repo_id}"))?; + Ok(Self::HfTokenizerJson(f)) } - Ok(name) } /// Checks if the provided path is a valid local repository path. @@ -905,7 +869,7 @@ mod tests { pub fn test_config_json_llama3() -> anyhow::Result<()> { let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json"); - let config = HFConfig::from_json_file(&config_file.display().to_string())?; + let config = HFConfig::from_json_file(&config_file)?; assert_eq!(config.bos_token_id(), 128000); Ok(()) } @@ -914,7 +878,7 @@ mod tests { pub fn test_config_json_llama4() -> anyhow::Result<()> { let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json"); - let config = HFConfig::from_json_file(&config_file.display().to_string())?; + let config = HFConfig::from_json_file(&config_file)?; assert_eq!(config.bos_token_id(), 200000); Ok(()) } diff --git a/lib/llm/src/preprocessor/prompt/template.rs b/lib/llm/src/preprocessor/prompt/template.rs index e307ba842f..1e47e38589 100644 --- a/lib/llm/src/preprocessor/prompt/template.rs +++ b/lib/llm/src/preprocessor/prompt/template.rs @@ -23,20 +23,34 @@ impl PromptFormatter { .as_ref() .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))? { - PromptFormatterArtifact::HfTokenizerConfigJson(file) => { + PromptFormatterArtifact::HfTokenizerConfigJson(checked_file) => { + let Some(file) = checked_file.path() else { + anyhow::bail!( + "HfTokenizerConfigJson for {} is a URL, cannot load", + mdc.display_name + ); + }; let content = std::fs::read_to_string(file) - .with_context(|| format!("fs:read_to_string '{file}'"))?; + .with_context(|| format!("fs:read_to_string '{}'", file.display()))?; let mut config: ChatTemplate = serde_json::from_str(&content)?; // Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8) // stores the chat template in a separate file, we check if the file exists and // put the chat template into config as normalization. // This may also be a custom template provided via CLI flag. - if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) = + if let Some(PromptFormatterArtifact::HfChatTemplate(checked_file)) = mdc.chat_template_file.as_ref() { - let chat_template = std::fs::read_to_string(chat_template_file) - .with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?; + let Some(chat_template_file) = checked_file.path() else { + anyhow::bail!( + "HfChatTemplate for {} is a URL, cannot load", + mdc.display_name + ); + }; + let chat_template = + std::fs::read_to_string(chat_template_file).with_context(|| { + format!("fs:read_to_string '{}'", chat_template_file.display()) + })?; // clean up the string to remove newlines let chat_template = chat_template.replace('\n', ""); config.chat_template = Some(ChatTemplateValue(either::Left(chat_template))); diff --git a/lib/runtime/src/transports/nats.rs b/lib/runtime/src/transports/nats.rs index 03844a36f4..7cf3676db1 100644 --- a/lib/runtime/src/transports/nats.rs +++ b/lib/runtime/src/transports/nats.rs @@ -173,10 +173,10 @@ impl Client { } /// Upload file to NATS at this URL - pub async fn object_store_upload(&self, filepath: &Path, nats_url: Url) -> anyhow::Result<()> { + pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> { let mut disk_file = TokioFile::open(filepath).await?; - let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; + let (bucket_name, key) = url_to_bucket_and_key(nats_url)?; let bucket = self.get_or_create_bucket(&bucket_name, true).await?; let key_meta = async_nats::jetstream::object_store::ObjectMetadata { @@ -193,12 +193,12 @@ impl Client { /// Download file from NATS at this URL pub async fn object_store_download( &self, - nats_url: Url, + nats_url: &Url, filepath: &Path, ) -> anyhow::Result<()> { let mut disk_file = TokioFile::create(filepath).await?; - let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; + let (bucket_name, key) = url_to_bucket_and_key(nats_url)?; let bucket = self.get_or_create_bucket(&bucket_name, false).await?; let mut obj_reader = bucket.get(&key).await.map_err(|e| { @@ -225,7 +225,7 @@ impl Client { } /// Upload a serializable struct to NATS object store using bincode - pub async fn object_store_upload_data(&self, data: &T, nats_url: Url) -> anyhow::Result<()> + pub async fn object_store_upload_data(&self, data: &T, nats_url: &Url) -> anyhow::Result<()> where T: Serialize, { @@ -233,7 +233,7 @@ impl Client { let binary_data = bincode::serialize(data) .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?; - let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; + let (bucket_name, key) = url_to_bucket_and_key(nats_url)?; let bucket = self.get_or_create_bucket(&bucket_name, true).await?; let key_meta = async_nats::jetstream::object_store::ObjectMetadata { @@ -251,11 +251,11 @@ impl Client { } /// Download and deserialize a struct from NATS object store using bincode - pub async fn object_store_download_data(&self, nats_url: Url) -> anyhow::Result + pub async fn object_store_download_data(&self, nats_url: &Url) -> anyhow::Result where T: DeserializeOwned, { - let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?; + let (bucket_name, key) = url_to_bucket_and_key(nats_url)?; let bucket = self.get_or_create_bucket(&bucket_name, false).await?; let mut obj_reader = bucket.get(&key).await.map_err(|e| { @@ -1078,13 +1078,13 @@ mod tests { // Upload the data client - .object_store_upload_data(&test_data, url.clone()) + .object_store_upload_data(&test_data, &url) .await .expect("Failed to upload data"); // Download the data let downloaded_data: TestData = client - .object_store_download_data(url.clone()) + .object_store_download_data(&url) .await .expect("Failed to download data");