5 changes: 0 additions & 5 deletions .releaserc
@@ -64,7 +64,6 @@
]
}
],
"@semantic-release/release-notes-generator",
"@semantic-release/github",
[
"semantic-release-cargo",
@@ -108,10 +107,6 @@
"type": "feat",
"section": "🍕 Features"
},
- {
-     "type": "feature",
-     "section": "🍕 Features"
- },
{
"type": "fix",
"section": "🐛 Bug Fixes"
2 changes: 1 addition & 1 deletion src/common.rs
@@ -97,7 +97,7 @@ pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Res

let mut tokenizer = tokenizer
.with_padding(Some(PaddingParams {
- // TODO: the user should able to choose the padding strategy
+ // TODO: the user should be able to choose the padding strategy
strategy: PaddingStrategy::BatchLongest,
pad_token,
pad_id,
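The TODO above leaves the padding strategy hard-coded to `PaddingStrategy::BatchLongest`. A minimal sketch of what a caller-selectable strategy could look like, assuming the `tokenizers` crate's `PaddingParams`/`PaddingStrategy` types; the `strategy` parameter and the helper itself are hypothetical, not part of this PR:

```rust
use tokenizers::{PaddingParams, PaddingStrategy};

// Hypothetical helper: lets the caller pick a padding strategy, falling
// back to the current behavior (pad to the longest sequence in a batch).
fn padding_params(
    strategy: Option<PaddingStrategy>,
    pad_id: u32,
    pad_token: String,
) -> PaddingParams {
    PaddingParams {
        strategy: strategy.unwrap_or(PaddingStrategy::BatchLongest),
        pad_id,
        pad_token,
        ..Default::default()
    }
}
```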
8 changes: 4 additions & 4 deletions src/image_embedding/utils.rs
@@ -156,7 +156,7 @@ impl Transform for Normalize {
Ok(TransformData::NdArray(array_normalized))
}
_ => Err(anyhow!(
"Transformer convert error. Normlize operator get error shape."
"Transformer convert error. Normalize operator got error shape."
)),
}
}
@@ -236,11 +236,11 @@ fn load_preprocessor(config: serde_json::Value) -> anyhow::Result<Compose> {
crop_size["height"]
.as_u64()
.map(|height| height as u32)
- .ok_or(anyhow!("crop_size height must be cotained"))?,
+ .ok_or(anyhow!("crop_size height must be contained"))?,
crop_size["width"]
.as_u64()
.map(|width| width as u32)
- .ok_or(anyhow!("crop_size width must be cotained"))?,
+ .ok_or(anyhow!("crop_size width must be contained"))?,
)
} else {
return Err(anyhow!("Invalid crop size: {:?}", crop_size));
@@ -325,7 +325,7 @@ fn load_preprocessor(config: serde_json::Value) -> anyhow::Result<Compose> {
}));
}
}
- mode => return Err(anyhow!("Preprocessror {} is not supported", mode)),
+ mode => return Err(anyhow!("Preprocessor {} is not supported", mode)),
}

transformers.push(Box::new(PILToNDarray));
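The corrected `crop_size` parsing follows a common `serde_json` idiom: optional extraction with `as_u64`, then `ok_or` to turn a missing or non-numeric field into an `anyhow` error. A self-contained sketch of that idiom; the `crop_hw` helper is hypothetical, distilled from the hunk above:

```rust
use anyhow::anyhow;

// Hypothetical standalone version of the crop_size extraction above:
// missing or non-numeric fields become errors instead of panics.
fn crop_hw(crop_size: &serde_json::Value) -> anyhow::Result<(u32, u32)> {
    let height = crop_size["height"]
        .as_u64()
        .map(|h| h as u32)
        .ok_or(anyhow!("crop_size height must be contained"))?;
    let width = crop_size["width"]
        .as_u64()
        .map(|w| w as u32)
        .ok_or(anyhow!("crop_size width must be contained"))?;
    Ok((height, width))
}
```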
10 changes: 5 additions & 5 deletions src/init.rs
@@ -49,17 +49,17 @@ impl<M: Default> Default for InitOptions<M> {
}

impl<M: Default + HasMaxLength> InitOptionsWithLength<M> {
- /// Crea a new InitOptionsWithLength with the given model name
+ /// Create a new InitOptionsWithLength with the given model name
pub fn new(model_name: M) -> Self {
Self {
model_name,
..Default::default()
}
}

- /// Set the maximum maximum length
- pub fn with_max_length(mut self, max_lenght: usize) -> Self {
-     self.max_length = max_lenght;
+ /// Set the maximum length
+ pub fn with_max_length(mut self, max_length: usize) -> Self {
+     self.max_length = max_length;
self
}

@@ -86,7 +86,7 @@ impl<M: Default + HasMaxLength> InitOptionsWithLength<M> {
}

impl<M: Default> InitOptions<M> {
- /// Crea a new InitOptionsWithLength with the given model name
+ /// Create a new InitOptions with the given model name
pub fn new(model_name: M) -> Self {
Self {
model_name,
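`with_max_length` is a consuming builder method: it takes `self` by value and returns `Self`, so calls chain and no half-configured struct escapes. A self-contained sketch of the pattern; `Options` and its fields are simplified stand-ins, not the crate's actual types:

```rust
#[derive(Debug, Default)]
struct Options {
    model_name: String,
    max_length: usize,
}

impl Options {
    // Start from defaults, overriding only the model name.
    fn new(model_name: String) -> Self {
        Self {
            model_name,
            ..Default::default()
        }
    }

    // Consume self and hand it back, so calls chain fluently.
    fn with_max_length(mut self, max_length: usize) -> Self {
        self.max_length = max_length;
        self
    }
}

fn main() {
    let opts = Options::new("my-model".into()).with_max_length(256);
    assert_eq!(opts.max_length, 256);
}
```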
14 changes: 9 additions & 5 deletions src/reranking/impl.rs
@@ -67,7 +67,7 @@ impl TextRerank {
let api = ApiBuilder::from_cache(cache)
.with_progress(show_download_progress)
.build()
- .expect("Failed to build API from cache");
+ .map_err(|e| anyhow::Error::msg(format!("Failed to build API from cache: {}", e)))?;
let model_repo = api.model(model_name.to_string());

let model_file_name = TextRerank::get_model_info(&model_name).model_file;
@@ -138,7 +138,7 @@ impl TextRerank {
let encodings = self
.tokenizer
.encode_batch(inputs, true)
- .expect("Failed to encode batch");
+ .map_err(|e| anyhow::Error::msg(e.to_string()).context("Failed to encode batch"))?;

let encoding_length = encodings[0].len();
let batch_size = batch.len();
@@ -176,9 +176,13 @@ impl TextRerank {
}

let outputs = self.session.run(session_inputs)?;
- let outputs = outputs["logits"]
+ let outputs = outputs
+     .get("logits")
+     .ok_or_else(|| anyhow::Error::msg("Output does not contain 'logits' key"))?
      .try_extract_array()
- .expect("Failed to extract logits tensor");
+     .map_err(|e| {
+         anyhow::Error::msg(format!("Failed to extract logits tensor: {}", e))
+     })?;
let batch_scores: Vec<f32> = outputs
.slice(s![.., 0])
.rows()
@@ -199,6 +203,6 @@ impl TextRerank {
})
.collect();
top_n_result.sort_by(|a, b| a.score.total_cmp(&b.score).reverse());
- Ok(top_n_result.to_vec())
+ Ok(top_n_result)
}
}
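The recurring change in this file is replacing panicking `.expect()` calls with conversions into `anyhow::Error` followed by `?`, so failures reach the caller as `Err` instead of aborting the process. A self-contained sketch of both conversions used above (`Option` via `ok_or_else`, a foreign error via `map_err`); the names here are hypothetical, not the crate's API:

```rust
use std::collections::HashMap;

use anyhow::Result;

// Hypothetical illustration of the two conversions used above.
fn lookup_and_parse(outputs: &HashMap<String, String>, key: &str) -> Result<i64> {
    // Option -> Result: a missing key becomes an error, not a panic.
    let raw = outputs
        .get(key)
        .ok_or_else(|| anyhow::Error::msg(format!("Output does not contain '{}' key", key)))?;
    // Foreign error -> anyhow::Error: keep the original message for context.
    raw.parse::<i64>()
        .map_err(|e| anyhow::Error::msg(format!("Failed to parse '{}': {}", raw, e)))
}

fn main() -> Result<()> {
    let outputs = HashMap::from([("logits".to_string(), "42".to_string())]);
    assert_eq!(lookup_and_parse(&outputs, "logits")?, 42);
    Ok(())
}
```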
2 changes: 1 addition & 1 deletion src/text_embedding/output.rs
@@ -19,7 +19,7 @@ pub const OUTPUT_TYPE_PRECEDENCE: &[OutputKey] = &[
// OutputKey::ByName("token_embeddings"),
];

- /// Generates thea default array transformer for the [`TextEmbedding`] model using the
+ /// Generates the default array transformer for the [`TextEmbedding`] model using the
/// provided output precedence.
///
// TODO (denwong47): now that pooling is done in SingleBatchOutput, it is possible that