From 4f5efaf153b3dc365794a6e085433831c1b5133a Mon Sep 17 00:00:00 2001 From: Angel Dijoux Date: Fri, 2 May 2025 18:31:25 +0200 Subject: [PATCH] feat: create models from model_code string --- src/models/image_embedding.rs | 22 +++++++++++++++++++++- src/models/reranking.rs | 22 +++++++++++++++++++++- src/models/sparse.rs | 22 +++++++++++++++++++++- src/models/text_embedding.rs | 22 +++++++++++++++++++++- 4 files changed, 84 insertions(+), 4 deletions(-) diff --git a/src/models/image_embedding.rs b/src/models/image_embedding.rs index 32931c1..185212c 100644 --- a/src/models/image_embedding.rs +++ b/src/models/image_embedding.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{fmt::Display, str::FromStr}; use super::model_info::ModelInfo; @@ -79,3 +79,23 @@ impl Display for ImageEmbeddingModel { write!(f, "{}", model_info.model_code) } } + +impl FromStr for ImageEmbeddingModel { + type Err = String; + + fn from_str(s: &str) -> Result { + models_list() + .into_iter() + .find(|m| m.model_code.eq_ignore_ascii_case(s)) + .map(|m| m.model) + .ok_or_else(|| format!("Unknown embedding model: {s}")) + } +} + +impl TryFrom for ImageEmbeddingModel { + type Error = String; + + fn try_from(value: String) -> Result { + value.parse() + } +} diff --git a/src/models/reranking.rs b/src/models/reranking.rs index 6a5dd31..5fb6b77 100644 --- a/src/models/reranking.rs +++ b/src/models/reranking.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{fmt::Display, str::FromStr}; use crate::RerankerModelInfo; @@ -57,3 +57,23 @@ impl Display for RerankerModel { write!(f, "{}", model_info.model_code) } } + +impl FromStr for RerankerModel { + type Err = String; + + fn from_str(s: &str) -> Result { + reranker_model_list() + .into_iter() + .find(|m| m.model_code.eq_ignore_ascii_case(s)) + .map(|m| m.model) + .ok_or_else(|| format!("Unknown reranker model: {s}")) + } +} + +impl TryFrom for RerankerModel { + type Error = String; + + fn try_from(value: String) -> Result { + value.parse() + } +} diff --git a/src/models/sparse.rs b/src/models/sparse.rs index 1e82ccf..f933a7a 100644 --- a/src/models/sparse.rs +++ b/src/models/sparse.rs @@ -1,4 +1,4 @@ -use std::fmt::Display; +use std::{fmt::Display, str::FromStr}; use crate::ModelInfo; @@ -28,3 +28,23 @@ impl Display for SparseModel { write!(f, "{}", model_info.model_code) } } + +impl FromStr for SparseModel { + type Err = String; + + fn from_str(s: &str) -> Result { + models_list() + .into_iter() + .find(|m| m.model_code.eq_ignore_ascii_case(s)) + .map(|m| m.model) + .ok_or_else(|| format!("Unknown sparse model: {s}")) + } +} + +impl TryFrom for SparseModel { + type Error = String; + + fn try_from(value: String) -> Result { + value.parse() + } +} diff --git a/src/models/text_embedding.rs b/src/models/text_embedding.rs index b9c4a2b..fa5b242 100644 --- a/src/models/text_embedding.rs +++ b/src/models/text_embedding.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::Display, sync::OnceLock}; +use std::{collections::HashMap, convert::TryFrom, fmt::Display, str::FromStr, sync::OnceLock}; use super::model_info::ModelInfo; @@ -360,3 +360,23 @@ impl Display for EmbeddingModel { write!(f, "{}", model_info.model_code) } } + +impl FromStr for EmbeddingModel { + type Err = String; + + fn from_str(s: &str) -> Result { + models_list() + .into_iter() + .find(|m| m.model_code.eq_ignore_ascii_case(s)) + .map(|m| m.model) + .ok_or_else(|| format!("Unknown embedding model: {s}")) + } +} + +impl TryFrom for EmbeddingModel { + type Error = String; + + fn try_from(value: String) -> Result { + value.parse() + } +}