feat: compile grammar and send over grpc

drbh 2024-03-06 02:40:56 +00:00
parent ad5f562aa5
commit 4f7074ca71
7 changed files with 100 additions and 46 deletions


@@ -47,6 +47,7 @@ pub async fn run(
        watermark,
        grammar: String::new(),
        grammar_type: GrammarType::None as i32,
+       states_to_token_maps: None,
    };

    // Initialize terminal properties


@@ -80,6 +80,18 @@ message NextTokenChooserParameters {
    string grammar = 10;
    /// grammar type
    GrammarType grammar_type = 11;
+   /// states to token maps
+   StatesToTokenMaps states_to_token_maps = 12;
+}
+
+/// StatesToTokenMaps maps to a BTreeMap<u32, BTreeMap<u32, u32>> in rust (start_state -> (token -> end_state))
+message StatesToTokenMaps {
+   /// Start state
+   repeated uint32 start_states = 1;
+   /// Token
+   repeated uint32 tokens = 2;
+   /// End state
+   repeated uint32 end_states = 3;
}

message StoppingCriteriaParameters {
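
Note: protobuf has no field type that matches the nested map directly, so the message flattens every transition (start_state, token, end_state) into three parallel repeated arrays of equal length. Below is a minimal Rust sketch of that conversion; the flatten/unflatten helpers are illustrative names and are not part of this commit.

use std::collections::BTreeMap;

type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;

// Illustrative only: flatten (start_state -> (token -> end_state)) into the
// three parallel repeated fields carried by the StatesToTokenMaps message.
fn flatten(maps: &StateTokenMaps) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
    let (mut start_states, mut tokens, mut end_states) = (vec![], vec![], vec![]);
    for (start, transitions) in maps {
        for (token, end) in transitions {
            start_states.push(*start);
            tokens.push(*token);
            end_states.push(*end);
        }
    }
    (start_states, tokens, end_states)
}

// Illustrative only: rebuild the nested map from the parallel arrays on the
// receiving side.
fn unflatten(start_states: &[u32], tokens: &[u32], end_states: &[u32]) -> StateTokenMaps {
    let mut maps = StateTokenMaps::new();
    for ((start, token), end) in start_states.iter().zip(tokens).zip(end_states) {
        maps.entry(*start).or_default().insert(*token, *end);
    }
    maps
}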


@@ -130,6 +130,7 @@ impl Client {
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
+                   states_to_token_maps: None,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: max_total_tokens - truncate,


@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-   NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+   NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;


@@ -48,6 +48,7 @@ impl Health {
                watermark: false,
                grammar: String::new(),
                grammar_type: ProtoGrammarType::None as i32,
+               states_to_token_maps: None,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: 1,


@@ -372,6 +372,7 @@ mod tests {
                watermark: false,
                grammar: String::new(),
                grammar_type: ProtoGrammarType::None as i32,
+               states_to_token_maps: None,
            },
            stopping_parameters: StoppingCriteriaParameters {
                ignore_eos_token: false,


@@ -7,7 +7,8 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use text_generation_client::{
-   GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+   GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
+   StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@@ -64,19 +65,36 @@ impl Validation {
                });
            }

-           // TODO: start more than one grammar compilation worker
-           let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+           // Create round robin channel
+           let (grammar_sender, grammar_round_robin_receiver) = mpsc::unbounded_channel();
+           let mut grammar_senders = Vec::with_capacity(workers);

-           // create single grammar compilation workers
-           tokio::task::spawn_blocking(move || {
-               grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
-                   tracing::error!("Error in grammar compilation worker: {:?}", e);
-                   e
-               })
-           });
+           // Create workers
+           for _ in 0..workers {
+               let tokenizer_clone = tokenizer.clone();
+               let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+               grammar_senders.push(grammar_sender);
+
+               // Spawn worker
+               tokio::task::spawn_blocking(move || {
+                   grammar_compilation_worker(tokenizer_clone, grammar_receiver).map_err(|e| {
+                       tracing::error!("Error in grammar compilation worker: {:?}", e);
+                       e
+                   })
+               });
+           }

            // Create tokenization round robin task
-           tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+           tokio::spawn(round_robin_task::<TokenizerRequest>(
+               validation_round_robin_receiver,
+               senders,
+           ));
+
+           // Create grammar compilation round robin task
+           tokio::spawn(round_robin_task::<GrammarCompilationRequest>(
+               grammar_round_robin_receiver,
+               grammar_senders,
+           ));

            (Some(validation_sender), Some(grammar_sender))
        } else {
@@ -121,7 +139,10 @@
    }

    #[instrument(skip(self, inputs))]
-   pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
+   pub async fn compile_grammar(
+       &self,
+       inputs: String,
+   ) -> Result<(String, StateTokenMaps), ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.grammar_compilation_sender {
            // Create response channel
@@ -138,7 +159,7 @@
            return Ok(encoding);
        }

-       Ok(inputs)
+       Ok((String::new(), BTreeMap::new()))
    }

    #[instrument(skip(self, inputs))]
@@ -347,7 +368,7 @@
        // compiler and use that to build the FSM here.

        // Validate grammar and unpack the grammar and type for the proto message
-       let (grammar, grammar_type) = match grammar {
+       let (grammar, grammar_type, states_to_token_maps) = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {
@@ -371,16 +392,27 @@
                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                    // NOTE: this is the first step to compile the grammar
-                   let regex_compiled_grammar = self
+                   let (regex_compiled_grammar, _states_to_token_maps) = self
                        .compile_grammar(serde_json::to_string(&json).unwrap())
                        .await
                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
-                   (regex_compiled_grammar, ProtoGrammarType::Regex.into())
+
+                   let stm = StatesToTokenMaps {
+                       start_states: vec![],
+                       tokens: vec![],
+                       end_states: vec![],
+                   };
+                   (
+                       regex_compiled_grammar,
+                       ProtoGrammarType::Regex.into(),
+                       Some(stm),
+                   )
                }
-               GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+               GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
            }
        }
-       None => (String::new(), ProtoGrammarType::None.into()),
+       None => (String::new(), ProtoGrammarType::None.into(), None),
    };

    let parameters = NextTokenChooserParameters {
@@ -395,6 +427,7 @@
        watermark,
        grammar,
        grammar_type,
+       states_to_token_maps,
    };
    let stopping_parameters = StoppingCriteriaParameters {
        max_new_tokens,
@@ -431,9 +464,9 @@
}

/// Round robin tokenization task
-async fn round_robin_task(
-   mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
-   senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+async fn round_robin_task<T>(
+   mut receiver: mpsc::UnboundedReceiver<T>,
+   senders: Vec<mpsc::UnboundedSender<T>>,
) {
    loop {
        for sender in &senders {
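
Note: the hunk above cuts off inside the dispatch loop. For context, a generic round-robin forwarder along these lines is what the new <T> signature enables; this is a sketch assuming the same recv-then-send pattern as the existing tokenization task, not code taken from this commit.

use tokio::sync::mpsc;

async fn round_robin_task<T>(
    mut receiver: mpsc::UnboundedReceiver<T>,
    senders: Vec<mpsc::UnboundedSender<T>>,
) {
    loop {
        for sender in &senders {
            // Forward the next request to the next worker in turn; stop once
            // the shared receiver is closed and drained.
            match receiver.recv().await {
                None => return,
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}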
@@ -507,10 +540,15 @@ fn prepare_input(
    Ok((encoding, inputs))
}

+type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;
+
/// Compile a grammar
-fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
+fn compile_grammar(
+   inputs: String,
+   tokenizer: &Tokenizer,
+) -> Result<(String, StateTokenMaps), ValidationError> {
    let start_time = std::time::Instant::now();
-   let schema = Python::with_gil(|py| -> PyResult<String> {
+   let (schema, states_to_token_maps) = Python::with_gil(|py| -> PyResult<(_, _)> {
        let fun: Py<PyAny> = PyModule::from_code(
            py,
            r#"
@@ -555,7 +593,7 @@ def adapt_tokenizer(vocab, special_tokens):
    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
    return tokenizer

-def compile_regex_grammar(inputs, vocab):
+def compile_regex_grammar(inputs, vocab, special_tokens):
    start_time = time.time()
    print("🔥 starting compile_regex_grammar", inputs)
    schema = build_regex_from_schema(inputs)
@@ -576,48 +614,48 @@ def convert_grammar_to_regex(inputs):
            "",
            "",
        )?
-       // .getattr("compile_regex_grammar")?
-       .getattr("convert_grammar_to_regex")?
-       .into();
+       .into_py(py);
+
+       let convert_grammar_to_regex = fun.getattr(py, "convert_grammar_to_regex")?;
+       let compile_regex_grammar = fun.getattr(py, "compile_regex_grammar")?;

        let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
        let special_tokens: Vec<String> = vec![];
-       let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
+       let regex_fsm = convert_grammar_to_regex.call(py, (inputs.clone(),), None)?;

-       if false {
-           let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
-           let regex_fsm_ref = regex_fsm.into_ref(py);
-           let states_to_token_maps = regex_fsm_ref
-               .getattr("states_to_token_maps")?
-               .extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
-           println!("🔥 elapsed: {:?}", start_time.elapsed());
-
-           // size of serialized states_to_token_maps
-           let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
-           println!(
-               "🔥 states_to_token_maps size: {:.2}MB",
-               states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
-           );
-       }
+       let compiled_grammar =
+           compile_regex_grammar.call(py, (inputs.clone(), args, special_tokens), None)?;
+       let compiled_grammar_ref = compiled_grammar.into_ref(py);
+       let states_to_token_maps = compiled_grammar_ref
+           .getattr("states_to_token_maps")?
+           .extract::<StateTokenMaps>()?;
+       println!("🔥 elapsed: {:?}", start_time.elapsed());
+
+       // size of serialized states_to_token_maps
+       let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
+       println!(
+           "🔥 states_to_token_maps size: {:.2}MB",
+           states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
+       );

        let result = regex_fsm.into_ref(py).extract().unwrap();
        println!("result: {:?}", result);
-       Ok(result)
+       Ok((result, states_to_token_maps))
    })
    .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

    let elapsed = start_time.elapsed();
    println!("🔥 elapsed: {:?}", elapsed);
-   Ok(schema)
+   Ok((schema, states_to_token_maps))
}

type GrammarCompilationRequest = (
    String,
-   oneshot::Sender<Result<String, ValidationError>>,
+   oneshot::Sender<Result<(String, StateTokenMaps), ValidationError>>,
    Span,
);
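
Note: grammar_compilation_worker itself is not shown in this diff. Assuming it mirrors the existing blocking tokenizer worker and the types defined in this file (Tokenizer, ValidationError, GrammarCompilationRequest, compile_grammar), each worker would drain its own channel and answer on the oneshot sender carried by the request. A rough sketch, not the committed implementation:

fn grammar_compilation_worker(
    tokenizer: Tokenizer,
    mut receiver: mpsc::UnboundedReceiver<GrammarCompilationRequest>,
) -> Result<(), ValidationError> {
    // Loop over requests until the matching round-robin sender is dropped.
    while let Some((inputs, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            // compile_grammar returns (regex, StateTokenMaps); the awaiting
            // Validation::compile_grammar call receives it on the oneshot.
            let _ = response_tx.send(compile_grammar(inputs, &tokenizer));
        })
    }
    Ok(())
}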