diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index 638c6514..2d545ba6 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        states_to_token_maps: None,
     };

     // Initialize terminal properties
diff --git a/proto/generate.proto b/proto/generate.proto
index 6351e37f..11c3fef5 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -80,6 +80,18 @@ message NextTokenChooserParameters {
     string grammar = 10;
     /// grammar type
     GrammarType grammar_type = 11;
+    /// states to token maps
+    StatesToTokenMaps states_to_token_maps = 12;
+}
+
+/// StatesToTokenMaps maps to a BTreeMap<u32, BTreeMap<u32, u32>> in rust (start_state -> (token -> end_state))
+message StatesToTokenMaps {
+    /// Start state
+    repeated uint32 start_states = 1;
+    /// Token
+    repeated uint32 tokens = 2;
+    /// End state
+    repeated uint32 end_states = 3;
 }

 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index f8658318..ce04bcca 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -130,6 +130,7 @@ impl Client {
                     watermark: true,
                     grammar: String::new(),
                     grammar_type: GrammarType::None as i32,
+                    states_to_token_maps: None,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 6782d9ff..d613362f 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+    NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/health.rs b/router/src/health.rs
index b05b3094..140367e8 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -48,6 +48,7 @@ impl Health {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 52ea16ca..dd3c7884 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -372,6 +372,7 @@ mod tests {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                states_to_token_maps: None,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 5aa87c49..8ac86136 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -7,7 +7,8 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
+    StoppingCriteriaParameters,
 };
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -64,19 +65,36 @@ impl Validation {
                 });
             }

-            // TODO: start more than one grammar compilation worker
-            let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+            // Create round robin channel
+            let (grammar_sender, grammar_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut grammar_senders = Vec::with_capacity(workers);

-            // create single grammar compilation workers
-            tokio::task::spawn_blocking(move || {
-                grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
-                    tracing::error!("Error in grammar compilation worker: {:?}", e);
-                    e
-                })
-            });
+            // Create workers
+            for _ in 0..workers {
+                let tokenizer_clone = tokenizer.clone();
+                let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+                grammar_senders.push(grammar_sender);
+
+                // Spawn worker
+                tokio::task::spawn_blocking(move || {
+                    grammar_compilation_worker(tokenizer_clone, grammar_receiver).map_err(|e| {
+                        tracing::error!("Error in grammar compilation worker: {:?}", e);
+                        e
+                    })
+                });
+            }

             // Create tokenization round robin task
-            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+            tokio::spawn(round_robin_task::<TokenizerRequest>(
+                validation_round_robin_receiver,
+                senders,
+            ));
+
+            // Create grammar compilation round robin task
+            tokio::spawn(round_robin_task::<GrammarCompilationRequest>(
+                grammar_round_robin_receiver,
+                grammar_senders,
+            ));

             (Some(validation_sender), Some(grammar_sender))
         } else {
@@ -121,7 +139,10 @@ impl Validation {
     }

     #[instrument(skip(self, inputs))]
-    pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
+    pub async fn compile_grammar(
+        &self,
+        inputs: String,
+    ) -> Result<(String, StateTokenMaps), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.grammar_compilation_sender {
             // Create response channel
@@ -138,7 +159,7 @@
             return Ok(encoding);
         }

-        Ok(inputs)
+        Ok((String::new(), BTreeMap::new()))
     }

     #[instrument(skip(self, inputs))]
@@ -347,7 +368,7 @@
         // compiler and use that to build the FSM here.

         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let (grammar, grammar_type, states_to_token_maps) = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
@@ -371,16 +392,27 @@
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                         // NOTE: this is the first step to compile the grammar
-                        let regex_compiled_grammar = self
+                        let (regex_compiled_grammar, _states_to_token_maps) = self
                             .compile_grammar(serde_json::to_string(&json).unwrap())
                             .await
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
-                        (regex_compiled_grammar, ProtoGrammarType::Regex.into())
+
+                        let stm = StatesToTokenMaps {
+                            start_states: vec![],
+                            tokens: vec![],
+                            end_states: vec![],
+                        };
+
+                        (
+                            regex_compiled_grammar,
+                            ProtoGrammarType::Regex.into(),
+                            Some(stm),
+                        )
                     }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
                 }
             }
-            None => (String::new(), ProtoGrammarType::None.into()),
+            None => (String::new(), ProtoGrammarType::None.into(), None),
         };

         let parameters = NextTokenChooserParameters {
@@ -395,6 +427,7 @@
             watermark,
             grammar,
             grammar_type,
+            states_to_token_maps,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
@@ -431,9 +464,9 @@
 }

 /// Round robin tokenization task
-async fn round_robin_task(
-    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
-    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+async fn round_robin_task<T>(
+    mut receiver: mpsc::UnboundedReceiver<T>,
+    senders: Vec<mpsc::UnboundedSender<T>>,
 ) {
     loop {
         for sender in &senders {
@@ -507,10 +540,15 @@ fn prepare_input(
     Ok((encoding, inputs))
 }

+type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;
+
 /// Compile a grammar
-fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
+fn compile_grammar(
+    inputs: String,
+    tokenizer: &Tokenizer,
+) -> Result<(String, StateTokenMaps), ValidationError> {
     let start_time = std::time::Instant::now();
-    let schema = Python::with_gil(|py| -> PyResult<String> {
+    let (schema, states_to_token_maps) = Python::with_gil(|py| -> PyResult<(_, _)> {
         let fun: Py<PyAny> = PyModule::from_code(
             py,
             r#"
@@ -555,7 +593,7 @@ def adapt_tokenizer(vocab, special_tokens):
     print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     return tokenizer

-def compile_regex_grammar(inputs, vocab):
+def compile_regex_grammar(inputs, vocab, special_tokens):
     start_time = time.time()
     print("🔥 starting compile_regex_grammar", inputs)
     schema = build_regex_from_schema(inputs)
@@ -576,48 +614,48 @@ def convert_grammar_to_regex(inputs):
             "",
             "",
         )?
-        // .getattr("compile_regex_grammar")?
-        .getattr("convert_grammar_to_regex")?
-        .into();
+        .into_py(py);
+
+        let convert_grammar_to_regex = fun.getattr(py, "convert_grammar_to_regex")?;
+        let compile_regex_grammar = fun.getattr(py, "compile_regex_grammar")?;

         let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
         let special_tokens: Vec<String> = vec![];
-        let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
+        let regex_fsm = convert_grammar_to_regex.call(py, (inputs.clone(),), None)?;

-        if false {
-            let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
-            let regex_fsm_ref = regex_fsm.into_ref(py);
+        let compiled_grammar =
+            compile_regex_grammar.call(py, (inputs.clone(), args, special_tokens), None)?;
+        let compiled_grammar_ref = compiled_grammar.into_ref(py);

-            let states_to_token_maps = regex_fsm_ref
-                .getattr("states_to_token_maps")?
-                .extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
+        let states_to_token_maps = compiled_grammar_ref
+            .getattr("states_to_token_maps")?
+            .extract::<StateTokenMaps>()?;

-            println!("🔥 elapsed: {:?}", start_time.elapsed());
+        println!("🔥 elapsed: {:?}", start_time.elapsed());

-            // size of serialized states_to_token_maps
-            let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
-            println!(
-                "🔥 states_to_token_maps size: {:.2}MB",
-                states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
-            );
-        }
+        // size of serialized states_to_token_maps
+        let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
+        println!(
+            "🔥 states_to_token_maps size: {:.2}MB",
+            states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
+        );

         let result = regex_fsm.into_ref(py).extract().unwrap();

         println!("result: {:?}", result);

-        Ok(result)
+        Ok((result, states_to_token_maps))
     })
     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

     let elapsed = start_time.elapsed();
     println!("🔥 elapsed: {:?}", elapsed);

-    Ok(schema)
+    Ok((schema, states_to_token_maps))
 }

 type GrammarCompilationRequest = (
     String,
-    oneshot::Sender<Result<String, ValidationError>>,
+    oneshot::Sender<Result<(String, StateTokenMaps), ValidationError>>,
     Span,
 );
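
Note on the wire format (not part of the patch): `StatesToTokenMaps` flattens the nested transition map into three parallel `repeated uint32` columns, so entry i encodes the transition start_states[i] --tokens[i]--> end_states[i]. Below is a minimal sketch of how that proto message corresponds to the Rust-side `StateTokenMaps` alias, assuming the prost-generated struct exposes the columns as `Vec<u32>` fields named as in the proto; the conversion helpers are illustrative only and do not appear in this diff.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for the prost-generated message
/// (the real type comes from the `text_generation_client` bindings).
#[derive(Default)]
pub struct StatesToTokenMaps {
    pub start_states: Vec<u32>,
    pub tokens: Vec<u32>,
    pub end_states: Vec<u32>,
}

/// Rust-side shape: start_state -> (token -> end_state).
pub type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;

/// Flatten the nested map into the three parallel columns of the proto message.
pub fn to_proto(maps: &StateTokenMaps) -> StatesToTokenMaps {
    let mut msg = StatesToTokenMaps::default();
    for (start_state, transitions) in maps {
        for (token, end_state) in transitions {
            msg.start_states.push(*start_state);
            msg.tokens.push(*token);
            msg.end_states.push(*end_state);
        }
    }
    msg
}

/// Rebuild the nested map from the parallel columns (e.g. on the receiving side).
pub fn from_proto(msg: &StatesToTokenMaps) -> StateTokenMaps {
    let mut maps = StateTokenMaps::new();
    for ((start_state, token), end_state) in msg
        .start_states
        .iter()
        .zip(msg.tokens.iter())
        .zip(msg.end_states.iter())
    {
        maps.entry(*start_state)
            .or_insert_with(BTreeMap::new)
            .insert(*token, *end_state);
    }
    maps
}
```

The flat layout keeps the schema to simple repeated scalars rather than nested `map` fields; the trade-off is that the receiving side has to zip the three columns back into the nested map, as sketched above.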