Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7fa4686
import llguidance modules
mmoskal Nov 6, 2024
1dcca91
llg constraint types
mmoskal Nov 6, 2024
eaf52be
integrate llguidance
mmoskal Nov 6, 2024
cc82722
fix handling of stop tokens
mmoskal Nov 6, 2024
fb83b5c
update toktrie
mmoskal Nov 6, 2024
00c36dd
remove submodules
mmoskal Nov 6, 2024
cab564b
fix version conflicts
mmoskal Nov 6, 2024
d1becbd
tok_trie -> tok_env rename
mmoskal Nov 6, 2024
f70b3de
update to latest llguidance
mmoskal Nov 30, 2024
3d22f2b
Merge branch 'master' into llg_cleanup
mmoskal Nov 30, 2024
ca9e346
sync lock
mmoskal Nov 30, 2024
7b3ae50
bump llg (lazy_static fix)
mmoskal Nov 30, 2024
146bc4c
update to latest llguidance, fix conflicts
mmoskal Nov 30, 2024
6019fee
import toktrie via llguidance
mmoskal Nov 30, 2024
dd35965
n=1
mmoskal Dec 1, 2024
3ac55ed
test with llama1b
mmoskal Dec 1, 2024
8ca7514
remove aici folder (no longer used)
mmoskal Dec 1, 2024
7967d42
use more specific type for llg grammars
mmoskal Dec 1, 2024
9a919d8
update python APIs to support json schema and llg
mmoskal Dec 1, 2024
0ad0948
update example to use lark not yacc
mmoskal Dec 1, 2024
816ac8f
rename example
mmoskal Dec 1, 2024
5fc906b
remove testing scripts
mmoskal Dec 1, 2024
5e9cbd2
re-export llguidance for easier LlguidanceGrammar construction
mmoskal Dec 2, 2024
ffcdd2a
fix formatting
mmoskal Dec 2, 2024
fda20fe
fix clippy
mmoskal Dec 2, 2024
b5add20
Merge branch 'master' into llg_cleanup
mmoskal Dec 2, 2024
2c59224
add python samples
mmoskal Dec 2, 2024
976092c
add server samples
mmoskal Dec 2, 2024
ac7c35d
add rust samples
mmoskal Dec 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
integrate llguidance
  • Loading branch information
mmoskal committed Nov 6, 2024
commit eaf52be02e0a4ed8c992acc8b4b5ebbb1e33ddbc
455 changes: 450 additions & 5 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ members = [
"mistralrs-bench",
"mistralrs-vision",
"mistralrs-quant",
"external/toktrie/core",
"external/toktrie/hf_tokenizers",
"external/llguidance/parser",
]
exclude = [
"mistralrs-paged_attn",
Expand Down Expand Up @@ -50,3 +53,7 @@ url = "2.5.2"
data-url = "0.3.1"
buildstructor = "0.5.4"
float8 = "0.1.1"

[patch.'https://github.com/microsoft/toktrie']
toktrie = { path = "external/toktrie/core" }
toktrie_hf_tokenizers = { path = "external/toktrie/hf_tokenizers" }
2 changes: 1 addition & 1 deletion external/toktrie
3 changes: 3 additions & 0 deletions mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ safetensors = "0.4.5"
serde_plain = "1.0.2"
as-any = "0.3.1"
float8.workspace = true
llguidance_parser = { path = "../external/llguidance/parser" }
toktrie = { path = "../external/toktrie/core" }
toktrie_hf_tokenizers = { path = "../external/toktrie/hf_tokenizers" }

[features]
pyo3_macros = ["pyo3"]
Expand Down
43 changes: 24 additions & 19 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use std::{
time::{Instant, SystemTime, UNIX_EPOCH},
};
use tokio::sync::{mpsc::Receiver, Mutex};
use toktrie::TokEnv;

use crate::{
aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx},
pipeline::{
text_models_inputs_processor::PagedAttentionMeta, AdapterInstruction, CacheBackendMetadata,
CacheInstruction,
llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
text_models_inputs_processor::PagedAttentionMeta,
AdapterInstruction, CacheBackendMetadata, CacheInstruction,
},
request::NormalRequest,
response::CompletionChoice,
Expand Down Expand Up @@ -455,15 +456,19 @@ impl Engine {
}
}

fn build_sequence_recognizer(constraint: &Constraint) -> anyhow::Result<SequenceRecognizer> {
let recognizer = match constraint {
Constraint::Regex(rx) => {
SequenceRecognizer::Regex(StackRecognizer::from(RecRx::from_rx(rx, None)?).into())
}
Constraint::Yacc(cfg) => SequenceRecognizer::Cfg(CfgParser::from_yacc(cfg)?.into()),
Constraint::None => SequenceRecognizer::None,
};
Ok(recognizer)
/// Builds the per-sequence recognizer used for constrained decoding.
///
/// If the request carries a constraint, it is compiled into a llguidance
/// grammar and paired with the pipeline's token environment; otherwise
/// `SequenceRecognizer::None` is returned.
///
/// # Errors
/// Fails if the constraint cannot be compiled, or if a constraint is present
/// but no token environment is available (`tok_env` is `None`).
fn build_sequence_recognizer(
tok_env: &Option<TokEnv>,
constraint: &Constraint,
) -> anyhow::Result<SequenceRecognizer> {
if let Some(grm) = llg_grammar_from_constraint(constraint)? {
// A token environment is mandatory once a grammar constraint is in play.
let tok_env = tok_env
.as_ref()
.ok_or_else(|| anyhow::anyhow!("No token environment found."))?;
let llg = constraint_from_llg_grammar(tok_env.clone(), grm)?;
Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
} else {
Ok(SequenceRecognizer::None)
}
}

async fn handle_request(&mut self, request: Request) {
Expand Down Expand Up @@ -668,6 +673,7 @@ impl Engine {
for id in i {
// We can't use ` ` (space) as a stop token because other tokens like ` moon` start with a space.
if let Some(tok_trie) = tok_trie.as_ref() {
let tok_trie = tok_trie.tok_trie();
if tok_trie.has_extensions(tok_trie.token(*id)) {
request
.response
Expand Down Expand Up @@ -712,6 +718,7 @@ impl Engine {

if toks.len() == 1 {
if tok_trie.as_ref().is_some_and(|tok_trie| {
let tok_trie = tok_trie.tok_trie();
tok_trie.has_extensions(tok_trie.token(toks[0]))
}) {
stop_strings.push(stop_txt.clone());
Expand Down Expand Up @@ -766,7 +773,11 @@ impl Engine {

// Add sequences
for response_index in 0..request.sampling_params.n_choices {
let recognizer = match Self::build_sequence_recognizer(&request.constraint) {
let trie = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.tok_trie
.clone();
let recognizer = match Self::build_sequence_recognizer(&trie, &request.constraint) {
Ok(recognizer) => recognizer,
Err(err) => {
request
Expand All @@ -785,11 +796,6 @@ impl Engine {
.cache_config
.clone()
.map(|conf| conf.block_size);
let trie = get_mut_arcmutex!(self.pipeline)
.get_metadata()
.tok_trie
.as_ref()
.map(|x| (**x).clone());
let seq = Sequence::new_waiting(
prompt_tokens.clone(),
prompt_text.clone(),
Expand All @@ -816,7 +822,6 @@ impl Engine {
request.adapters.clone(),
images.clone(),
block_size,
trie,
matcher.clone(),
image_generation_format,
seq_step_type,
Expand Down
1 change: 0 additions & 1 deletion mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use std::{
};
use tokio::sync::mpsc::{channel, Sender};

mod aici;
mod cuda;
mod device_map;
mod engine;
Expand Down
1 change: 0 additions & 1 deletion mistralrs-core/src/pipeline/amoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ fn new_dummy_seq(
None, // TODO incorrect for PagedAttention
None,
None,
None,
SeqStepType::PromptAndDecode,
None,
)
Expand Down
5 changes: 2 additions & 3 deletions mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, ForwardInputsResult,
IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use super::llg::build_tok_env;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::{get_chat_template, Cache};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
Expand Down Expand Up @@ -356,7 +355,7 @@ impl Loader for GGMLLoader {
Model::Llama(ref l) => l.max_seq_len,
Model::XLoraLlama(ref xl) => xl.max_seq_len,
};
let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
let tok_trie = build_tok_env(tokenizer.clone()).into();
let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.lock().len(),
Model::XLoraLlama(ref model) => model.cache.lock().len(),
Expand Down
5 changes: 2 additions & 3 deletions mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::cache_manager::DefaultCacheManager;
use super::llg::build_tok_env;
use super::{
get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, PrettyName, QuantizationKind,
Expand All @@ -8,8 +9,6 @@ use super::{
AdapterActivationMixin, AnyMoePipelineMixin, CacheManagerMixin, ForwardInputsResult,
IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
use crate::gguf::{
get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
Expand Down Expand Up @@ -484,7 +483,7 @@ impl Loader for GGUFLoader {
Model::Starcoder2(ref p) => p.max_seq_len,
Model::Qwen2(ref p) => p.max_seq_len,
};
let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
let tok_trie = build_tok_env(tokenizer.clone()).into();
let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.lock().len(),
Model::Phi2(ref model) => model.cache.lock().len(),
Expand Down
52 changes: 52 additions & 0 deletions mistralrs-core/src/pipeline/llg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::sync::Arc;

use anyhow::Result;
use llguidance_parser::{
api::{ParserLimits, RegexNode, TopLevelGrammar},
lark::{lark_to_llguidance, parse_lark},
JsonCompileOptions, TokenParser,
};
use tokenizers::Tokenizer;
use toktrie::{InferenceCapabilities, TokEnv};

use crate::Constraint;

/// Builds a llguidance token environment (`TokEnv`) from a Hugging Face
/// `Tokenizer`, by wrapping it in toktrie's byte-level tokenizer adapter.
///
/// # Panics
/// Panics if the tokenizer cannot be converted to a `ByteTokenizer`, or if
/// the `ByteTokenizerEnv` cannot be constructed from it.
pub fn build_tok_env(tokenizer: Tokenizer) -> TokEnv {
let bt = toktrie_hf_tokenizers::ByteTokenizer::from_tokenizer(tokenizer)
.expect("Failed to create ByteTokenizer from Tokenizer");
// NOTE(review): the second argument (`None`) is presumably an optional
// EOS/vocab-size override — confirm against toktrie_hf_tokenizers docs.
let env = toktrie_hf_tokenizers::ByteTokenizerEnv::new(bt, None)
.expect("Failed to create ByteTokenizerEnv");
// `TokEnv` is an `Arc`-based shared handle; wrap the env for sharing.
Arc::new(env)
}

/// Converts a user-facing `Constraint` into a llguidance `TopLevelGrammar`.
///
/// Returns `Ok(None)` for `Constraint::None` (no constrained decoding) and
/// `Ok(Some(grammar))` for regex, Lark, JSON-schema, and raw llguidance
/// constraints. Parse/compile errors are propagated via `?`.
pub fn llg_grammar_from_constraint(constraint: &Constraint) -> Result<Option<TopLevelGrammar>> {
let grm = match constraint {
Constraint::Regex(regex) => {
TopLevelGrammar::from_regex(RegexNode::Regex(regex.to_string()))
}
// Lark source is parsed first, then lowered to the llguidance form.
Constraint::Lark(lark) => lark_to_llguidance(parse_lark(lark)?)?,
// NOTE(review): schema validation is skipped (`json_to_llg_no_validate`);
// confirm the JSON schema is validated elsewhere if that matters.
Constraint::JsonSchema(value) => {
JsonCompileOptions::default().json_to_llg_no_validate(value)?
}
// Raw llguidance grammar supplied directly as JSON.
Constraint::Llguidance(value) => serde_json::from_value(value.clone())?,
Constraint::None => return Ok(None),
};
Ok(Some(grm))
}

pub fn constraint_from_llg_grammar(
tok_env: TokEnv,
grm: TopLevelGrammar,
) -> Result<llguidance_parser::Constraint> {
let parser = TokenParser::from_llguidance_json(
tok_env,
grm,
llguidance_parser::Logger::new(0, 1),
InferenceCapabilities {
..Default::default()
},
ParserLimits::default(),
vec![],
)?;
Ok(llguidance_parser::Constraint::new(parser))
}
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ mod processing;
mod sampling;
mod speculative;
mod vision;
pub(crate) mod llg;

pub use super::diffusion_models::DiffusionGenerationParams;
use crate::aici::toktree::TokTrie;
use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use crate::diffusion_models::response::send_responses;
use crate::paged_attention::{CacheConfig, CacheEngine};
Expand Down Expand Up @@ -66,7 +66,7 @@ use self::text_models_inputs_processor::PagedAttentionMeta;
pub struct GeneralMetadata {
pub max_seq_len: usize,
/// Only None if it doesn't make sense for the model
pub tok_trie: Option<Arc<TokTrie>>,
pub tok_trie: Option<toktrie::TokEnv>,
pub has_no_kv_cache: bool,
pub num_hidden_layers: usize,
pub eos_tok: Vec<u32>,
Expand Down
5 changes: 2 additions & 3 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use super::{
AutoLoader, Gemma2Loader, GemmaLoader, LlamaLoader, MistralLoader, MixtralLoader,
NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Starcoder2Loader,
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
use super::llg::build_tok_env;
use crate::amoe::AnyMoeExpertType;
use crate::lora::Ordering;
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
Expand Down Expand Up @@ -426,7 +425,7 @@ impl Loader for NormalLoader {
};

let max_seq_len = model.max_seq_len();
let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
let tok_trie = build_tok_env(tokenizer.clone()).into();
let num_hidden_layers = model.cache().lock().len();
let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
let sliding_window = model.config().sliding_window;
Expand Down
44 changes: 16 additions & 28 deletions mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use candle_core::{DType, Device, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
get_bias_if_not_allowed,
prefix_cacher::PrefixCacheManager,
sampler::Logprobs,
sequence::{Sequence, SequenceRecognizer},
Expand All @@ -30,6 +29,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
"`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
.to_string(),
))?
.tok_trie()
.decode(&[logprobs.token]),
&is_done,
);
Expand Down Expand Up @@ -325,25 +325,23 @@ pub async fn sample_sequence(
};

let bias_if_not_allowed = match &mut seq.recognizer {
SequenceRecognizer::Regex(ref mut rx) => {
get_bias_if_not_allowed!(seq.tok_trie, rx.as_mut(), first_lobprobs_response.token)
}
SequenceRecognizer::Cfg(ref mut cfg) => {
get_bias_if_not_allowed!(seq.tok_trie, cfg.as_mut(), first_lobprobs_response.token)
SequenceRecognizer::Llguidance(ref mut llg) => {
let bias = llg.compute_mask().map_err(candle_core::Error::msg)?;
if let Some(mask) = &bias.sample_mask {
if mask.is_allowed(first_lobprobs_response.token) {
None
} else {
Some(mask)
}
} else {
None
}
}
SequenceRecognizer::None => None,
};
let second_logprobs_response = match bias_if_not_allowed {
Some(token_set) => {
let mut acc = vec![
-f32::INFINITY;
seq.tok_trie
.as_ref()
.ok_or(candle_core::Error::Msg(
"TokTrie must be present in pipeline if bias is calculated".to_string()
))?
.vocab_size()
];
let mut acc = vec![-f32::INFINITY; token_set.len()];
token_set.apply_to(&mut acc);
let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?;

Expand Down Expand Up @@ -374,20 +372,10 @@ pub async fn sample_sequence(
None => first_lobprobs_response,
};

if add_to_trie && seq.tok_trie.is_some() {
if add_to_trie {
match seq.recognizer {
SequenceRecognizer::Regex(ref mut rx) => {
seq.tok_trie
.as_ref()
.unwrap()
.append_token(rx.as_mut(), second_logprobs_response.token)
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::Cfg(ref mut cfg) => {
seq.tok_trie
.as_ref()
.unwrap()
.append_token(cfg.as_mut(), second_logprobs_response.token)
SequenceRecognizer::Llguidance(ref mut llg) => {
llg.commit_token(Some(second_logprobs_response.token))
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::None => {}
Expand Down
22 changes: 2 additions & 20 deletions mistralrs-core/src/pipeline/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,26 +579,8 @@ impl Pipeline for SpeculativePipeline {
)
.await?;
match seq.recognizer {
SequenceRecognizer::Regex(ref mut rx) => {
get_mut_arcmutex!(self.target)
.get_metadata()
.tok_trie
.as_ref()
.ok_or(candle_core::Error::Msg(
"`SpeculativePipeline::step` requires a token trie".to_string(),
))?
.append_token(rx.as_mut(), accepted.token)
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::Cfg(ref mut cfg) => {
get_mut_arcmutex!(self.target)
.get_metadata()
.tok_trie
.as_ref()
.ok_or(candle_core::Error::Msg(
"`SpeculativePipeline::step` requires a token trie".to_string(),
))?
.append_token(cfg.as_mut(), accepted.token)
SequenceRecognizer::Llguidance(ref mut llg) => {
llg.commit_token(Some(accepted.token))
.map_err(candle_core::Error::msg)?;
}
SequenceRecognizer::None => {}
Expand Down
Loading