diff --git a/src/common.rs b/src/common.rs
index 5023c94..ba820c3 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -112,20 +112,36 @@ pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Res
     if let serde_json::Value::Object(root_object) = special_tokens_map {
         for (_, value) in root_object.iter() {
             if value.is_string() {
-                tokenizer.add_special_tokens(&[AddedToken {
-                    content: value.as_str().unwrap().into(),
-                    special: true,
-                    ..Default::default()
-                }]);
+                if let Some(content) = value.as_str() {
+                    tokenizer.add_special_tokens(&[AddedToken {
+                        content: content.into(),
+                        special: true,
+                        ..Default::default()
+                    }]);
+                }
             } else if value.is_object() {
-                tokenizer.add_special_tokens(&[AddedToken {
-                    content: value["content"].as_str().unwrap().into(),
-                    special: true,
-                    single_word: value["single_word"].as_bool().unwrap(),
-                    lstrip: value["lstrip"].as_bool().unwrap(),
-                    rstrip: value["rstrip"].as_bool().unwrap(),
-                    normalized: value["normalized"].as_bool().unwrap(),
-                }]);
+                if let (
+                    Some(content),
+                    Some(single_word),
+                    Some(lstrip),
+                    Some(rstrip),
+                    Some(normalized),
+                ) = (
+                    value["content"].as_str(),
+                    value["single_word"].as_bool(),
+                    value["lstrip"].as_bool(),
+                    value["rstrip"].as_bool(),
+                    value["normalized"].as_bool(),
+                ) {
+                    tokenizer.add_special_tokens(&[AddedToken {
+                        content: content.into(),
+                        special: true,
+                        single_word,
+                        lstrip,
+                        rstrip,
+                        normalized,
+                    }]);
+                }
             }
         }
     }
diff --git a/src/pooling.rs b/src/pooling.rs
index cd0e6b3..9169f61 100644
--- a/src/pooling.rs
+++ b/src/pooling.rs
@@ -59,13 +59,13 @@ pub fn mean(
     let attention_mask = attention_mask_array
         .insert_axis(ndarray::Axis(2))
         .broadcast(token_embeddings.dim())
-        .unwrap_or_else(|| {
-            panic!(
+        .ok_or_else(|| {
+            anyhow::Error::msg(format!(
                 "Could not broadcast attention mask from {:?} to {:?}",
                 attention_mask_original_dim,
                 token_embeddings.dim()
-            )
-        })
+            ))
+        })?
         .mapv(|x| x as f32);
 
     let masked_tensor = &attention_mask * &token_embeddings;
diff --git a/src/sparse_text_embedding/impl.rs b/src/sparse_text_embedding/impl.rs
index f8c29db..189010b 100644
--- a/src/sparse_text_embedding/impl.rs
+++ b/src/sparse_text_embedding/impl.rs
@@ -119,7 +119,9 @@ impl SparseTextEmbedding {
             .map(|batch| {
                 // Encode the texts in the batch
                 let inputs = batch.iter().map(|text| text.as_ref()).collect();
-                let encodings = self.tokenizer.encode_batch(inputs, true).unwrap();
+                let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| {
+                    anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.")
+                })?;
 
                 // Extract the encoding length and batch size
                 let encoding_length = encodings[0].len();