From ee0a8ff9f22da8ae8f1dc73e8c45426a19bdf5e3 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 19 Nov 2025 15:10:13 +0530 Subject: [PATCH] refactor: Improved Error Handling Signed-off-by: Anush008 --- src/image_embedding/impl.rs | 15 +++++++++++---- src/image_embedding/utils.rs | 32 ++++++++++++++++++++----------- src/models/image_embedding.rs | 2 +- src/models/reranking.rs | 2 +- src/models/sparse.rs | 2 +- src/models/text_embedding.rs | 2 +- src/reranking/impl.rs | 7 +++++-- src/sparse_text_embedding/impl.rs | 12 +++++++++--- src/text_embedding/impl.rs | 5 ++++- src/text_embedding/output.rs | 12 +++++++++--- 10 files changed, 63 insertions(+), 28 deletions(-) diff --git a/src/image_embedding/impl.rs b/src/image_embedding/impl.rs index 304969c..9087a65 100644 --- a/src/image_embedding/impl.rs +++ b/src/image_embedding/impl.rs @@ -122,7 +122,7 @@ impl ImageEmbedding { ImageEmbedding::list_supported_models() .into_iter() .find(|m| &m.model == model) - .expect("Model not found.") + .expect("Model not found in supported models list. This is a bug - please report it.") } /// Method to generate image embeddings for a Vec of image bytes @@ -217,7 +217,10 @@ impl ImageEmbedding { // Try to get the only output key // If multiple, then default to few known keys `image_embeds` and `last_hidden_state` let last_hidden_state_key = match outputs.len() { - 1 => vec![outputs.keys().next().unwrap()], + 1 => vec![outputs + .keys() + .next() + .ok_or_else(|| anyhow!("Expected one output but found none"))?], _ => vec!["image_embeds", "last_hidden_state"], }; @@ -252,8 +255,12 @@ impl ImageEmbedding { // For 2D output [batch_size, hidden_size] output_array .outer_iter() - .map(|row| normalize(row.as_slice().unwrap())) - .collect() + .map(|row| { + row.as_slice() + .ok_or_else(|| anyhow!("Failed to convert array row to slice")) + .map(normalize) + }) + .collect::<anyhow::Result<Vec<_>>>()? 
} _ => { return Err(anyhow!( diff --git a/src/image_embedding/utils.rs b/src/image_embedding/utils.rs index 7d1faf2..72e1c3d 100644 --- a/src/image_embedding/utils.rs +++ b/src/image_embedding/utils.rs @@ -142,17 +142,21 @@ impl Transform for Normalize { let array = data.array()?; let mean = Array::from_vec(self.mean.clone()) .into_shape_with_order((3, 1, 1)) - .unwrap(); + .map_err(|e| anyhow!("Failed to reshape mean array: {}", e))?; let std = Array::from_vec(self.std.clone()) .into_shape_with_order((3, 1, 1)) - .unwrap(); + .map_err(|e| anyhow!("Failed to reshape std array: {}", e))?; let shape = array.shape().to_vec(); match shape.as_slice() { [c, h, w] => { - let array_normalized = array - .sub(mean.broadcast((*c, *h, *w)).unwrap()) - .div(std.broadcast((*c, *h, *w)).unwrap()); + let mean_broadcast = mean.broadcast((*c, *h, *w)).ok_or_else(|| { + anyhow!("Failed to broadcast mean array to shape {:?}", (*c, *h, *w)) + })?; + let std_broadcast = std.broadcast((*c, *h, *w)).ok_or_else(|| { + anyhow!("Failed to broadcast std array to shape {:?}", (*c, *h, *w)) + })?; + let array_normalized = array.sub(mean_broadcast).div(std_broadcast); Ok(TransformData::NdArray(array_normalized)) } _ => Err(anyhow!( @@ -229,18 +233,21 @@ fn load_preprocessor(config: serde_json::Value) -> anyhow::Result { if config["do_center_crop"].as_bool().unwrap_or(false) { let crop_size = config["crop_size"].clone(); let (height, width) = if crop_size.is_u64() { - let size = crop_size.as_u64().unwrap() as u32; + let size = crop_size + .as_u64() + .ok_or_else(|| anyhow!("crop_size must be a valid u64"))? 
+ as u32; (size, size) } else if crop_size.is_object() { ( crop_size["height"] .as_u64() .map(|height| height as u32) - .ok_or(anyhow!("crop_size height must be contained"))?, + .ok_or_else(|| anyhow!("crop_size height must be contained"))?, crop_size["width"] .as_u64() .map(|width| width as u32) - .ok_or(anyhow!("crop_size width must be contained"))?, + .ok_or_else(|| anyhow!("crop_size width must be contained"))?, ) } else { return Err(anyhow!("Invalid crop size: {:?}", crop_size)); @@ -304,18 +311,21 @@ fn load_preprocessor(config: serde_json::Value) -> anyhow::Result { if config["do_center_crop"].as_bool().unwrap_or(false) { let crop_size = config["crop_size"].clone(); let (height, width) = if crop_size.is_u64() { - let size = crop_size.as_u64().unwrap() as u32; + let size = crop_size + .as_u64() + .ok_or_else(|| anyhow!("crop_size must be a valid u64"))? + as u32; (size, size) } else if crop_size.is_object() { ( crop_size["height"] .as_u64() .map(|height| height as u32) - .ok_or(anyhow!("crop_size height must be contained"))?, + .ok_or_else(|| anyhow!("crop_size height must be contained"))?, crop_size["width"] .as_u64() .map(|width| width as u32) - .ok_or(anyhow!("crop_size width must be contained"))?, + .ok_or_else(|| anyhow!("crop_size width must be contained"))?, ) } else { return Err(anyhow!("Invalid crop size: {:?}", crop_size)); diff --git a/src/models/image_embedding.rs b/src/models/image_embedding.rs index 59a4fde..c071875 100644 --- a/src/models/image_embedding.rs +++ b/src/models/image_embedding.rs @@ -81,7 +81,7 @@ impl Display for ImageEmbeddingModel { let model_info = models_list() .into_iter() .find(|model| model.model == *self) - .unwrap(); + .ok_or(std::fmt::Error)?; write!(f, "{}", model_info.model_code) } } diff --git a/src/models/reranking.rs b/src/models/reranking.rs index 5b10da1..f63511f 100644 --- a/src/models/reranking.rs +++ b/src/models/reranking.rs @@ -54,7 +54,7 @@ impl Display for RerankerModel { let model_info = 
reranker_model_list() .into_iter() .find(|model| model.model == *self) - .expect("Model not found in supported models list."); + .ok_or(std::fmt::Error)?; write!(f, "{}", model_info.model_code) } } diff --git a/src/models/sparse.rs b/src/models/sparse.rs index 1b52d2f..f14502e 100644 --- a/src/models/sparse.rs +++ b/src/models/sparse.rs @@ -26,7 +26,7 @@ impl Display for SparseModel { let model_info = models_list() .into_iter() .find(|model| model.model == *self) - .unwrap(); + .ok_or(std::fmt::Error)?; write!(f, "{}", model_info.model_code) } } diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs index 8292872..edca79f 100644 --- a/src/models/text_embedding.rs +++ b/src/models/text_embedding.rs @@ -402,7 +402,7 @@ impl ModelTrait for EmbeddingModel { impl Display for EmbeddingModel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let model_info = EmbeddingModel::get_model_info(self).expect("Model not found."); + let model_info = EmbeddingModel::get_model_info(self).ok_or(std::fmt::Error)?; write!(f, "{}", model_info.model_code) } } diff --git a/src/reranking/impl.rs b/src/reranking/impl.rs index 0d8cae6..ce6dce9 100644 --- a/src/reranking/impl.rs +++ b/src/reranking/impl.rs @@ -42,7 +42,7 @@ impl TextRerank { TextRerank::list_supported_models() .into_iter() .find(|m| &m.model == model) - .expect("Model not found.") + .expect("Model not found in supported models list. This is a bug - please report it.") } pub fn list_supported_models() -> Vec { @@ -140,7 +140,10 @@ impl TextRerank { .encode_batch(inputs, true) .map_err(|e| anyhow::Error::msg(e.to_string()).context("Failed to encode batch"))?; - let encoding_length = encodings[0].len(); + let encoding_length = encodings + .first() + .ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))? 
+ .len(); let batch_size = batch.len(); let max_size = encoding_length * batch_size; diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs index 189010b..fabb1db 100644 --- a/src/sparse_text_embedding/impl.rs +++ b/src/sparse_text_embedding/impl.rs @@ -101,7 +101,7 @@ impl SparseTextEmbedding { SparseTextEmbedding::list_supported_models() .into_iter() .find(|m| &m.model == model) - .expect("Model not found.") + .expect("Model not found in supported models list. This is a bug - please report it.") } /// Method to generate sentence embeddings for a Vec of texts @@ -124,7 +124,10 @@ impl SparseTextEmbedding { })?; // Extract the encoding length and batch size - let encoding_length = encodings[0].len(); + let encoding_length = encodings + .first() + .ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))? + .len(); let batch_size = batch.len(); let max_size = encoding_length * batch_size; @@ -170,7 +173,10 @@ impl SparseTextEmbedding { // Try to get the only output key // If multiple, then default to `last_hidden_state` let last_hidden_state_key = match outputs.len() { - 1 => outputs.keys().next().unwrap(), + 1 => outputs + .keys() + .next() + .ok_or_else(|| anyhow::anyhow!("Expected one output but found none"))?, _ => "last_hidden_state", }; diff --git a/src/text_embedding/impl.rs b/src/text_embedding/impl.rs index 3a0fdc0..e83530d 100644 --- a/src/text_embedding/impl.rs +++ b/src/text_embedding/impl.rs @@ -300,7 +300,10 @@ impl TextEmbedding { })?; // Extract the encoding length and batch size - let encoding_length = encodings[0].len(); + let encoding_length = encodings + .first() + .ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))? 
+ .len(); let batch_size = batch.len(); let max_size = encoding_length * batch_size; diff --git a/src/text_embedding/output.rs b/src/text_embedding/output.rs index 7f50e94..a7deec9 100644 --- a/src/text_embedding/output.rs +++ b/src/text_embedding/output.rs @@ -37,12 +37,18 @@ pub fn transformer_with_precedence( .map(|batch| { batch .select_and_pool_output(&output_precedence, pooling.clone()) - .map(|array| { + .and_then(|array| { array .rows() .into_iter() - .map(|row| normalize(row.as_slice().unwrap())) - .collect::<Vec<_>>() + .map(|row| { + row.as_slice() + .ok_or_else(|| { + anyhow::anyhow!("Failed to convert array row to slice") + }) + .map(normalize) + }) + .collect::<anyhow::Result<Vec<_>>>() }) }) .try_fold(Vec::new(), |mut acc, res| {