Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/image_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl ImageEmbedding {
ImageEmbedding::list_supported_models()
.into_iter()
.find(|m| &m.model == model)
.expect("Model not found.")
.expect("Model not found in supported models list. This is a bug - please report it.")
}

/// Method to generate image embeddings for a Vec of image bytes
Expand Down Expand Up @@ -217,7 +217,10 @@ impl ImageEmbedding {
// Try to get the only output key
// If multiple, then default to few known keys `image_embeds` and `last_hidden_state`
let last_hidden_state_key = match outputs.len() {
1 => vec![outputs.keys().next().unwrap()],
1 => vec![outputs
.keys()
.next()
.ok_or_else(|| anyhow!("Expected one output but found none"))?],
_ => vec!["image_embeds", "last_hidden_state"],
};

Expand Down Expand Up @@ -252,8 +255,12 @@ impl ImageEmbedding {
// For 2D output [batch_size, hidden_size]
output_array
.outer_iter()
.map(|row| normalize(row.as_slice().unwrap()))
.collect()
.map(|row| {
row.as_slice()
.ok_or_else(|| anyhow!("Failed to convert array row to slice"))
.map(normalize)
})
.collect::<anyhow::Result<Vec<_>>>()?
}
_ => {
return Err(anyhow!(
Expand Down
32 changes: 21 additions & 11 deletions src/image_embedding/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,21 @@ impl Transform for Normalize {
let array = data.array()?;
let mean = Array::from_vec(self.mean.clone())
.into_shape_with_order((3, 1, 1))
.unwrap();
.map_err(|e| anyhow!("Failed to reshape mean array: {}", e))?;
let std = Array::from_vec(self.std.clone())
.into_shape_with_order((3, 1, 1))
.unwrap();
.map_err(|e| anyhow!("Failed to reshape std array: {}", e))?;

let shape = array.shape().to_vec();
match shape.as_slice() {
[c, h, w] => {
let array_normalized = array
.sub(mean.broadcast((*c, *h, *w)).unwrap())
.div(std.broadcast((*c, *h, *w)).unwrap());
let mean_broadcast = mean.broadcast((*c, *h, *w)).ok_or_else(|| {
anyhow!("Failed to broadcast mean array to shape {:?}", (*c, *h, *w))
})?;
let std_broadcast = std.broadcast((*c, *h, *w)).ok_or_else(|| {
anyhow!("Failed to broadcast std array to shape {:?}", (*c, *h, *w))
})?;
let array_normalized = array.sub(mean_broadcast).div(std_broadcast);
Ok(TransformData::NdArray(array_normalized))
}
_ => Err(anyhow!(
Expand Down Expand Up @@ -229,18 +233,21 @@ fn load_preprocessor(config: serde_json::Value) -> anyhow::Result<Compose> {
if config["do_center_crop"].as_bool().unwrap_or(false) {
let crop_size = config["crop_size"].clone();
let (height, width) = if crop_size.is_u64() {
let size = crop_size.as_u64().unwrap() as u32;
let size = crop_size
.as_u64()
.ok_or_else(|| anyhow!("crop_size must be a valid u64"))?
as u32;
(size, size)
} else if crop_size.is_object() {
(
crop_size["height"]
.as_u64()
.map(|height| height as u32)
.ok_or(anyhow!("crop_size height must be contained"))?,
.ok_or_else(|| anyhow!("crop_size height must be contained"))?,
crop_size["width"]
.as_u64()
.map(|width| width as u32)
.ok_or(anyhow!("crop_size width must be contained"))?,
.ok_or_else(|| anyhow!("crop_size width must be contained"))?,
)
} else {
return Err(anyhow!("Invalid crop size: {:?}", crop_size));
Expand Down Expand Up @@ -304,18 +311,21 @@ fn load_preprocessor(config: serde_json::Value) -> anyhow::Result<Compose> {
if config["do_center_crop"].as_bool().unwrap_or(false) {
let crop_size = config["crop_size"].clone();
let (height, width) = if crop_size.is_u64() {
let size = crop_size.as_u64().unwrap() as u32;
let size = crop_size
.as_u64()
.ok_or_else(|| anyhow!("crop_size must be a valid u64"))?
as u32;
(size, size)
} else if crop_size.is_object() {
(
crop_size["height"]
.as_u64()
.map(|height| height as u32)
.ok_or(anyhow!("crop_size height must be contained"))?,
.ok_or_else(|| anyhow!("crop_size height must be contained"))?,
crop_size["width"]
.as_u64()
.map(|width| width as u32)
.ok_or(anyhow!("crop_size width must be contained"))?,
.ok_or_else(|| anyhow!("crop_size width must be contained"))?,
)
} else {
return Err(anyhow!("Invalid crop size: {:?}", crop_size));
Expand Down
2 changes: 1 addition & 1 deletion src/models/image_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl Display for ImageEmbeddingModel {
let model_info = models_list()
.into_iter()
.find(|model| model.model == *self)
.unwrap();
.ok_or(std::fmt::Error)?;
write!(f, "{}", model_info.model_code)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/reranking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Display for RerankerModel {
let model_info = reranker_model_list()
.into_iter()
.find(|model| model.model == *self)
.expect("Model not found in supported models list.");
.ok_or(std::fmt::Error)?;
write!(f, "{}", model_info.model_code)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl Display for SparseModel {
let model_info = models_list()
.into_iter()
.find(|model| model.model == *self)
.unwrap();
.ok_or(std::fmt::Error)?;
write!(f, "{}", model_info.model_code)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/models/text_embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl ModelTrait for EmbeddingModel {

impl Display for EmbeddingModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let model_info = EmbeddingModel::get_model_info(self).expect("Model not found.");
let model_info = EmbeddingModel::get_model_info(self).ok_or(std::fmt::Error)?;
write!(f, "{}", model_info.model_code)
}
}
Expand Down
7 changes: 5 additions & 2 deletions src/reranking/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl TextRerank {
TextRerank::list_supported_models()
.into_iter()
.find(|m| &m.model == model)
.expect("Model not found.")
.expect("Model not found in supported models list. This is a bug - please report it.")
}

pub fn list_supported_models() -> Vec<RerankerModelInfo> {
Expand Down Expand Up @@ -140,7 +140,10 @@ impl TextRerank {
.encode_batch(inputs, true)
.map_err(|e| anyhow::Error::msg(e.to_string()).context("Failed to encode batch"))?;

let encoding_length = encodings[0].len();
let encoding_length = encodings
.first()
.ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))?
.len();
let batch_size = batch.len();
let max_size = encoding_length * batch_size;

Expand Down
12 changes: 9 additions & 3 deletions src/sparse_text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl SparseTextEmbedding {
SparseTextEmbedding::list_supported_models()
.into_iter()
.find(|m| &m.model == model)
.expect("Model not found.")
.expect("Model not found in supported models list. This is a bug - please report it.")
}

/// Method to generate sentence embeddings for a Vec of texts
Expand All @@ -124,7 +124,10 @@ impl SparseTextEmbedding {
})?;

// Extract the encoding length and batch size
let encoding_length = encodings[0].len();
let encoding_length = encodings
.first()
.ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))?
.len();
let batch_size = batch.len();

let max_size = encoding_length * batch_size;
Expand Down Expand Up @@ -170,7 +173,10 @@ impl SparseTextEmbedding {
// Try to get the only output key
// If multiple, then default to `last_hidden_state`
let last_hidden_state_key = match outputs.len() {
1 => outputs.keys().next().unwrap(),
1 => outputs
.keys()
.next()
.ok_or_else(|| anyhow::anyhow!("Expected one output but found none"))?,
_ => "last_hidden_state",
};

Expand Down
5 changes: 4 additions & 1 deletion src/text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,10 @@ impl TextEmbedding {
})?;

// Extract the encoding length and batch size
let encoding_length = encodings[0].len();
let encoding_length = encodings
.first()
.ok_or_else(|| anyhow::anyhow!("Tokenizer returned empty encodings"))?
.len();
let batch_size = batch.len();

let max_size = encoding_length * batch_size;
Expand Down
12 changes: 9 additions & 3 deletions src/text_embedding/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ pub fn transformer_with_precedence(
.map(|batch| {
batch
.select_and_pool_output(&output_precedence, pooling.clone())
.map(|array| {
.and_then(|array| {
array
.rows()
.into_iter()
.map(|row| normalize(row.as_slice().unwrap()))
.collect::<Vec<Embedding>>()
.map(|row| {
row.as_slice()
.ok_or_else(|| {
anyhow::anyhow!("Failed to convert array row to slice")
})
.map(normalize)
})
.collect::<anyhow::Result<Vec<Embedding>>>()
})
})
.try_fold(Vec::new(), |mut acc, res| {
Expand Down
Loading