Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-25 20:12:07 +00:00
feat: compile grammar and send over grpc
parent ad5f562aa5
commit 4f7074ca71
@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        states_to_token_maps: None,
     };

     // Initialize terminal properties
@@ -80,6 +80,18 @@ message NextTokenChooserParameters {
     string grammar = 10;
     /// grammar type
     GrammarType grammar_type = 11;
+    /// states to token maps
+    StatesToTokenMaps states_to_token_maps = 12;
 }

+/// StatesToTokenMaps maps to a BTreeMap<u32, BTreeMap<u32, u32>> in rust (start_state -> (token -> end_state))
+message StatesToTokenMaps {
+    /// Start state
+    repeated uint32 start_states = 1;
+    /// Token
+    repeated uint32 tokens = 2;
+    /// End state
+    repeated uint32 end_states = 3;
+}
+
 message StoppingCriteriaParameters {
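Not part of the diff: a minimal Rust sketch of how the three index-aligned repeated fields above could be folded back into the nested map described by the comment (start_state -> (token -> end_state)). The helper name is made up for illustration.

use std::collections::BTreeMap;

/// Hypothetical helper: rebuild (start_state -> (token -> end_state)) from the
/// three parallel vectors carried by the proto message. Assumes the vectors are
/// index-aligned and of equal length.
fn unflatten(
    start_states: &[u32],
    tokens: &[u32],
    end_states: &[u32],
) -> BTreeMap<u32, BTreeMap<u32, u32>> {
    let mut maps: BTreeMap<u32, BTreeMap<u32, u32>> = BTreeMap::new();
    for ((&start, &token), &end) in start_states.iter().zip(tokens).zip(end_states) {
        maps.entry(start).or_default().insert(token, end);
    }
    maps
}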
@@ -130,6 +130,7 @@ impl Client {
                 watermark: true,
                 grammar: String::new(),
                 grammar_type: GrammarType::None as i32,
+                states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+    NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -48,6 +48,7 @@ impl Health {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -372,6 +372,7 @@ mod tests {
             watermark: false,
             grammar: String::new(),
             grammar_type: ProtoGrammarType::None as i32,
+            states_to_token_maps: None,
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,
@@ -7,7 +7,8 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
+    StoppingCriteriaParameters,
 };
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -64,19 +65,36 @@ impl Validation {
                 });
             }

-            // TODO: start more than one grammar compilation worker
-            let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+            // Create round robin channel
+            let (grammar_sender, grammar_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut grammar_senders = Vec::with_capacity(workers);

-            // create single grammar compilation workers
-            tokio::task::spawn_blocking(move || {
-                grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
-                    tracing::error!("Error in grammar compilation worker: {:?}", e);
-                    e
-                })
-            });
+            // Create workers
+            for _ in 0..workers {
+                let tokenizer_clone = tokenizer.clone();
+                let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+                grammar_senders.push(grammar_sender);
+
+                // Spawn worker
+                tokio::task::spawn_blocking(move || {
+                    grammar_compilation_worker(tokenizer_clone, grammar_receiver).map_err(|e| {
+                        tracing::error!("Error in grammar compilation worker: {:?}", e);
+                        e
+                    })
+                });
+            }

             // Create tokenization round robin task
-            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+            tokio::spawn(round_robin_task::<TokenizerRequest>(
+                validation_round_robin_receiver,
+                senders,
+            ));
+
+            // Create grammar compilation round robin task
+            tokio::spawn(round_robin_task::<GrammarCompilationRequest>(
+                grammar_round_robin_receiver,
+                grammar_senders,
+            ));

             (Some(validation_sender), Some(grammar_sender))
         } else {
@@ -121,7 +139,10 @@ impl Validation {
     }

     #[instrument(skip(self, inputs))]
-    pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
+    pub async fn compile_grammar(
+        &self,
+        inputs: String,
+    ) -> Result<(String, StateTokenMaps), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.grammar_compilation_sender {
             // Create response channel
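Not from the commit: a small illustration of the second element of the new return type. `StateTokenMaps` is the alias introduced further down in this diff; the states and token ids below are invented purely for the example.

use std::collections::BTreeMap;

// Same shape as the alias defined later in validation.rs.
type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;

// Toy FSM fragment: from state 0, token 42 leads to state 1; from state 1,
// token 7 leads to state 2. Real maps come out of the compiled grammar.
fn example_map() -> StateTokenMaps {
    let mut maps: StateTokenMaps = BTreeMap::new();
    maps.entry(0).or_default().insert(42, 1);
    maps.entry(1).or_default().insert(7, 2);
    maps
}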
@@ -138,7 +159,7 @@ impl Validation {
             return Ok(encoding);
         }

-        Ok(inputs)
+        Ok((String::new(), BTreeMap::new()))
     }

     #[instrument(skip(self, inputs))]
@@ -347,7 +368,7 @@ impl Validation {
         // compiler and use that to build the FSM here.

         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let (grammar, grammar_type, states_to_token_maps) = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
@@ -371,16 +392,27 @@
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                         // NOTE: this is the first step to compile the grammar
-                        let regex_compiled_grammar = self
+                        let (regex_compiled_grammar, _states_to_token_maps) = self
                             .compile_grammar(serde_json::to_string(&json).unwrap())
                             .await
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
-                        (regex_compiled_grammar, ProtoGrammarType::Regex.into())
+
+                        let stm = StatesToTokenMaps {
+                            start_states: vec![],
+                            tokens: vec![],
+                            end_states: vec![],
+                        };
+
+                        (
+                            regex_compiled_grammar,
+                            ProtoGrammarType::Regex.into(),
+                            Some(stm),
+                        )
                     }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
                 }
             }
-            None => (String::new(), ProtoGrammarType::None.into()),
+            None => (String::new(), ProtoGrammarType::None.into(), None),
         };

         let parameters = NextTokenChooserParameters {
@@ -395,6 +427,7 @@
             watermark,
             grammar,
             grammar_type,
+            states_to_token_maps,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
@@ -431,9 +464,9 @@
 }

 /// Round robin tokenization task
-async fn round_robin_task(
-    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
-    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+async fn round_robin_task<T>(
+    mut receiver: mpsc::UnboundedReceiver<T>,
+    senders: Vec<mpsc::UnboundedSender<T>>,
 ) {
     loop {
         for sender in &senders {
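For context, a hedged sketch of what the now-generic dispatcher's body does: each request pulled from the shared receiver is handed to the next worker sender in rotation. The signature matches the diff; the body is reconstructed from the fragment shown above, not copied from upstream.

use tokio::sync::mpsc;

async fn round_robin_task<T>(
    mut receiver: mpsc::UnboundedReceiver<T>,
    senders: Vec<mpsc::UnboundedSender<T>>,
) {
    loop {
        // Cycle through the worker senders, forwarding one request to each.
        for sender in &senders {
            match receiver.recv().await {
                None => return, // inlet closed: nothing left to dispatch
                Some(request) => sender.send(request).unwrap(),
            }
        }
    }
}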
@@ -507,10 +540,15 @@ fn prepare_input(
     Ok((encoding, inputs))
 }

+type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;
+
 /// Compile a grammar
-fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
+fn compile_grammar(
+    inputs: String,
+    tokenizer: &Tokenizer,
+) -> Result<(String, StateTokenMaps), ValidationError> {
     let start_time = std::time::Instant::now();
-    let schema = Python::with_gil(|py| -> PyResult<String> {
+    let (schema, states_to_token_maps) = Python::with_gil(|py| -> PyResult<(_, _)> {
         let fun: Py<PyAny> = PyModule::from_code(
             py,
             r#"
@@ -555,7 +593,7 @@ def adapt_tokenizer(vocab, special_tokens):
     print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     return tokenizer

-def compile_regex_grammar(inputs, vocab):
+def compile_regex_grammar(inputs, vocab, special_tokens):
     start_time = time.time()
     print("🔥 starting compile_regex_grammar", inputs)
     schema = build_regex_from_schema(inputs)
@@ -576,48 +614,48 @@ def convert_grammar_to_regex(inputs):
             "",
             "",
         )?
-        // .getattr("compile_regex_grammar")?
-        .getattr("convert_grammar_to_regex")?
-        .into();
+        .into_py(py);
+
+        let convert_grammar_to_regex = fun.getattr(py, "convert_grammar_to_regex")?;
+        let compile_regex_grammar = fun.getattr(py, "compile_regex_grammar")?;

         let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
         let special_tokens: Vec<String> = vec![];

-        let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
+        let regex_fsm = convert_grammar_to_regex.call(py, (inputs.clone(),), None)?;

-        if false {
-            let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
-            let regex_fsm_ref = regex_fsm.into_ref(py);
+        let compiled_grammar =
+            compile_regex_grammar.call(py, (inputs.clone(), args, special_tokens), None)?;
+        let compiled_grammar_ref = compiled_grammar.into_ref(py);

-            let states_to_token_maps = regex_fsm_ref
-                .getattr("states_to_token_maps")?
-                .extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
+        let states_to_token_maps = compiled_grammar_ref
+            .getattr("states_to_token_maps")?
+            .extract::<StateTokenMaps>()?;

-            println!("🔥 elapsed: {:?}", start_time.elapsed());
+        println!("🔥 elapsed: {:?}", start_time.elapsed());

-            // size of serialized states_to_token_maps
-            let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
-            println!(
-                "🔥 states_to_token_maps size: {:.2}MB",
-                states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
-            );
-        }
+        // size of serialized states_to_token_maps
+        let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
+        println!(
+            "🔥 states_to_token_maps size: {:.2}MB",
+            states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
+        );

         let result = regex_fsm.into_ref(py).extract().unwrap();

         println!("result: {:?}", result);

-        Ok(result)
+        Ok((result, states_to_token_maps))
     })
     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
     let elapsed = start_time.elapsed();
     println!("🔥 elapsed: {:?}", elapsed);
-    Ok(schema)
+    Ok((schema, states_to_token_maps))
 }

 type GrammarCompilationRequest = (
     String,
-    oneshot::Sender<Result<String, ValidationError>>,
+    oneshot::Sender<Result<(String, StateTokenMaps), ValidationError>>,
     Span,
 );
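The hunks above extract `states_to_token_maps` from the compiled Python FSM, but, as seen earlier, the request parameters still carry an empty placeholder. As a hedged sketch (not in this commit), flattening the extracted map into the proto message's three repeated fields could look like this; the struct below merely stands in for the prost-generated `StatesToTokenMaps` type.

use std::collections::BTreeMap;

// Stand-in for the generated proto type (three repeated uint32 fields).
#[derive(Default)]
struct StatesToTokenMaps {
    start_states: Vec<u32>,
    tokens: Vec<u32>,
    end_states: Vec<u32>,
}

/// Flatten (start_state -> (token -> end_state)) into index-aligned vectors
/// suitable for sending over gRPC.
fn flatten(maps: &BTreeMap<u32, BTreeMap<u32, u32>>) -> StatesToTokenMaps {
    let mut out = StatesToTokenMaps::default();
    for (&start, transitions) in maps {
        for (&token, &end) in transitions {
            out.start_states.push(start);
            out.tokens.push(token);
            out.end_states.push(end);
        }
    }
    out
}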