feat: compile grammar and send over grpc

drbh 2024-03-06 02:40:56 +00:00
parent ad5f562aa5
commit 4f7074ca71
7 changed files with 100 additions and 46 deletions

View File

@ -47,6 +47,7 @@ pub async fn run(
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
states_to_token_maps: None,
};
// Initialize terminal properties

View File

@ -80,6 +80,18 @@ message NextTokenChooserParameters {
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
/// states to token maps
StatesToTokenMaps states_to_token_maps = 12;
}
/// StatesToTokenMaps maps to a BTreeMap<u32, BTreeMap<u32, u32>> in Rust (start_state -> (token -> end_state))
message StatesToTokenMaps {
/// Start states (parallel with tokens and end_states)
repeated uint32 start_states = 1;
/// Tokens
repeated uint32 tokens = 2;
/// End states
repeated uint32 end_states = 3;
}
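
The three repeated fields are parallel arrays: entry i of start_states, tokens and end_states together encodes one transition. A minimal sketch of how a consumer could rebuild the nested map described in the comment above; the struct is a stand-in for the prost-generated message, with field names following the proto definition:

use std::collections::BTreeMap;

// Stand-in for the prost-generated StatesToTokenMaps message.
pub struct StatesToTokenMaps {
    pub start_states: Vec<u32>,
    pub tokens: Vec<u32>,
    pub end_states: Vec<u32>,
}

// Rebuild start_state -> (token -> end_state) from the parallel arrays.
fn to_nested_map(msg: &StatesToTokenMaps) -> BTreeMap<u32, BTreeMap<u32, u32>> {
    let mut map: BTreeMap<u32, BTreeMap<u32, u32>> = BTreeMap::new();
    for ((&start, &token), &end) in msg
        .start_states
        .iter()
        .zip(msg.tokens.iter())
        .zip(msg.end_states.iter())
    {
        map.entry(start).or_default().insert(token, end);
    }
    map
}
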
message StoppingCriteriaParameters {

View File

@ -130,6 +130,7 @@ impl Client {
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
states_to_token_maps: None,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,

View File

@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -48,6 +48,7 @@ impl Health {
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
states_to_token_maps: None,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@ -372,6 +372,7 @@ mod tests {
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
states_to_token_maps: None,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,

View File

@ -7,7 +7,8 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@ -64,19 +65,36 @@ impl Validation {
});
}
// TODO: start more than one grammar compilation worker
let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
// Create round robin channel
let (grammar_sender, grammar_round_robin_receiver) = mpsc::unbounded_channel();
let mut grammar_senders = Vec::with_capacity(workers);
// create single grammar compilation workers
// Create workers
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
grammar_senders.push(grammar_sender);
// Spawn worker
tokio::task::spawn_blocking(move || {
grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
grammar_compilation_worker(tokenizer_clone, grammar_receiver).map_err(|e| {
tracing::error!("Error in grammar compilation worker: {:?}", e);
e
})
});
}
// Create tokenization round robin task
tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
tokio::spawn(round_robin_task::<TokenizerRequest>(
validation_round_robin_receiver,
senders,
));
// Create grammar compilation round robin task
tokio::spawn(round_robin_task::<GrammarCompilationRequest>(
grammar_round_robin_receiver,
grammar_senders,
));
(Some(validation_sender), Some(grammar_sender))
} else {
@ -121,7 +139,10 @@ impl Validation {
}
#[instrument(skip(self, inputs))]
pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
pub async fn compile_grammar(
&self,
inputs: String,
) -> Result<(String, StateTokenMaps), ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.grammar_compilation_sender {
// Create response channel
@ -138,7 +159,7 @@ impl Validation {
return Ok(encoding);
}
Ok(inputs)
Ok((String::new(), BTreeMap::new()))
}
#[instrument(skip(self, inputs))]
@ -347,7 +368,7 @@ impl Validation {
// compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type) = match grammar {
let (grammar, grammar_type, states_to_token_maps) = match grammar {
Some(grammar) => {
// Ensure that grammar is not set if it's not supported
if self.disable_grammar_support {
@ -371,16 +392,27 @@ impl Validation {
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
// NOTE: this is the first step to compile the grammar
let regex_compiled_grammar = self
let (regex_compiled_grammar, _states_to_token_maps) = self
.compile_grammar(serde_json::to_string(&json).unwrap())
.await
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
(regex_compiled_grammar, ProtoGrammarType::Regex.into())
let stm = StatesToTokenMaps {
start_states: vec![],
tokens: vec![],
end_states: vec![],
};
(
regex_compiled_grammar,
ProtoGrammarType::Regex.into(),
Some(stm),
)
}
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
}
}
None => (String::new(), ProtoGrammarType::None.into()),
None => (String::new(), ProtoGrammarType::None.into(), None),
};
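
The JSON-grammar arm above currently sends an empty placeholder. As a sketch (the helper name is hypothetical), the _states_to_token_maps returned by compile_grammar could be flattened into the same parallel-array layout the proto message expects:

// Hypothetical helper: flatten start_state -> (token -> end_state) into the
// three parallel arrays carried by the StatesToTokenMaps proto message.
fn flatten_states_to_token_maps(map: &StateTokenMaps) -> StatesToTokenMaps {
    let mut stm = StatesToTokenMaps {
        start_states: vec![],
        tokens: vec![],
        end_states: vec![],
    };
    for (&start, transitions) in map {
        for (&token, &end) in transitions {
            stm.start_states.push(start);
            stm.tokens.push(token);
            stm.end_states.push(end);
        }
    }
    stm
}

With that helper, the arm would return Some(flatten_states_to_token_maps(&_states_to_token_maps)) instead of the empty stm.
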
let parameters = NextTokenChooserParameters {
@ -395,6 +427,7 @@ impl Validation {
watermark,
grammar,
grammar_type,
states_to_token_maps,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
@ -431,9 +464,9 @@ impl Validation {
}
/// Round robin dispatch task
async fn round_robin_task(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
async fn round_robin_task<T>(
mut receiver: mpsc::UnboundedReceiver<T>,
senders: Vec<mpsc::UnboundedSender<T>>,
) {
loop {
for sender in &senders {
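
The hunk ends mid-loop; for reference, a sketch of the full generic dispatcher the new signature implies, assuming the body is unchanged from the tokenization version: each incoming request is forwarded to the next worker in rotation, and the task exits when the request channel closes.

use tokio::sync::mpsc;

/// Round robin dispatch task (sketch).
async fn round_robin_task<T>(
    mut receiver: mpsc::UnboundedReceiver<T>,
    senders: Vec<mpsc::UnboundedSender<T>>,
) {
    loop {
        for sender in &senders {
            match receiver.recv().await {
                // All request senders dropped: shut the task down.
                None => return,
                // Forward the request to the next worker in the rotation.
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}
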
@ -507,10 +540,15 @@ fn prepare_input(
Ok((encoding, inputs))
}
type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;
/// Compile a grammar
fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
fn compile_grammar(
inputs: String,
tokenizer: &Tokenizer,
) -> Result<(String, StateTokenMaps), ValidationError> {
let start_time = std::time::Instant::now();
let schema = Python::with_gil(|py| -> PyResult<String> {
let (schema, states_to_token_maps) = Python::with_gil(|py| -> PyResult<(_, _)> {
let fun: Py<PyAny> = PyModule::from_code(
py,
r#"
@ -555,7 +593,7 @@ def adapt_tokenizer(vocab, special_tokens):
print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
return tokenizer
def compile_regex_grammar(inputs, vocab):
def compile_regex_grammar(inputs, vocab, special_tokens):
start_time = time.time()
print("🔥 starting compile_regex_grammar", inputs)
schema = build_regex_from_schema(inputs)
@ -576,22 +614,23 @@ def convert_grammar_to_regex(inputs):
"",
"",
)?
// .getattr("compile_regex_grammar")?
.getattr("convert_grammar_to_regex")?
.into();
.into_py(py);
let convert_grammar_to_regex = fun.getattr(py, "convert_grammar_to_regex")?;
let compile_regex_grammar = fun.getattr(py, "compile_regex_grammar")?;
let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
let special_tokens: Vec<String> = vec![];
let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
let regex_fsm = convert_grammar_to_regex.call(py, (inputs.clone(),), None)?;
if false {
let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
let regex_fsm_ref = regex_fsm.into_ref(py);
let compiled_grammar =
compile_regex_grammar.call(py, (inputs.clone(), args, special_tokens), None)?;
let compiled_grammar_ref = compiled_grammar.into_ref(py);
let states_to_token_maps = regex_fsm_ref
let states_to_token_maps = compiled_grammar_ref
.getattr("states_to_token_maps")?
.extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
.extract::<StateTokenMaps>()?;
println!("🔥 elapsed: {:?}", start_time.elapsed());
@ -601,23 +640,22 @@ def convert_grammar_to_regex(inputs):
"🔥 states_to_token_maps size: {:.2}MB",
states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
);
}
let result = regex_fsm.into_ref(py).extract().unwrap();
println!("result: {:?}", result);
Ok(result)
Ok((result, states_to_token_maps))
})
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
let elapsed = start_time.elapsed();
println!("🔥 elapsed: {:?}", elapsed);
Ok(schema)
Ok((schema, states_to_token_maps))
}
type GrammarCompilationRequest = (
String,
oneshot::Sender<Result<String, ValidationError>>,
oneshot::Sender<Result<(String, StateTokenMaps), ValidationError>>,
Span,
);
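
For context on how these pieces fit together, a sketch of the request/response flow through this channel type. grammar_compilation_worker is not shown in this diff, so its body below is illustrative and assumes it mirrors the existing tokenization worker (the ValidationError return type is likewise assumed); send_grammar_request is a hypothetical free-function rendering of the caller side in Validation::compile_grammar. Both use the types defined in this file.

use tokio::sync::{mpsc, oneshot};

// Worker side: drain requests off the channel and answer each one on its
// oneshot sender. Runs inside tokio::task::spawn_blocking, hence blocking_recv.
fn grammar_compilation_worker(
    tokenizer: Tokenizer,
    mut receiver: mpsc::UnboundedReceiver<GrammarCompilationRequest>,
) -> Result<(), ValidationError> {
    while let Some((inputs, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            // Compile the grammar and send the result back; the caller may
            // have gone away, so a failed send is simply ignored.
            let _ = response_tx.send(compile_grammar(inputs, &tokenizer));
        });
    }
    Ok(())
}

// Caller side: create a oneshot response channel, enqueue the request,
// then await the worker's reply.
async fn send_grammar_request(
    sender: &mpsc::UnboundedSender<GrammarCompilationRequest>,
    inputs: String,
) -> Result<(String, StateTokenMaps), ValidationError> {
    let (response_tx, response_rx) = oneshot::channel();
    sender
        .send((inputs, response_tx, tracing::Span::current()))
        .unwrap();
    response_rx.await.unwrap()
}
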