mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-26 20:42:06 +00:00

feat: compile grammar and send over grpc

parent ad5f562aa5
commit 4f7074ca71
@@ -47,6 +47,7 @@ pub async fn run(
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
+        states_to_token_maps: None,
     };

     // Initialize terminal properties
@@ -80,6 +80,18 @@ message NextTokenChooserParameters {
     string grammar = 10;
     /// grammar type
     GrammarType grammar_type = 11;
+    /// states to token maps
+    StatesToTokenMaps states_to_token_maps = 12;
+}
+
+/// StatesToTokenMaps maps to a BTreeMap<u32, BTreeMap<u32, u32>> in rust (start_state -> (token -> end_state))
+message StatesToTokenMaps {
+    /// Start state
+    repeated uint32 start_states = 1;
+    /// Token
+    repeated uint32 tokens = 2;
+    /// End state
+    repeated uint32 end_states = 3;
 }

 message StoppingCriteriaParameters {
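The new message flattens the nested Rust map into three parallel arrays, one entry per (start_state, token) -> end_state transition, since protobuf has no map-of-maps. A minimal sketch of the conversion in both directions; the flatten/unflatten helpers and the hand-written struct are illustrative only (the real struct is prost-generated from the message above):

use std::collections::BTreeMap;

// Illustrative stand-in for the prost-generated StatesToTokenMaps message.
#[derive(Default, Debug)]
struct StatesToTokenMaps {
    start_states: Vec<u32>,
    tokens: Vec<u32>,
    end_states: Vec<u32>,
}

type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;

// Flatten the nested map into three parallel arrays:
// one entry per (start_state, token) -> end_state transition.
fn flatten(map: &StateTokenMaps) -> StatesToTokenMaps {
    let mut out = StatesToTokenMaps::default();
    for (&start, transitions) in map {
        for (&token, &end) in transitions {
            out.start_states.push(start);
            out.tokens.push(token);
            out.end_states.push(end);
        }
    }
    out
}

// Rebuild the nested map on the receiving side.
fn unflatten(msg: &StatesToTokenMaps) -> StateTokenMaps {
    let mut map = StateTokenMaps::new();
    for i in 0..msg.start_states.len() {
        map.entry(msg.start_states[i])
            .or_insert_with(BTreeMap::new)
            .insert(msg.tokens[i], msg.end_states[i]);
    }
    map
}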
@@ -130,6 +130,7 @@ impl Client {
                 watermark: true,
                 grammar: String::new(),
                 grammar_type: GrammarType::None as i32,
+                states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+    NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -48,6 +48,7 @@ impl Health {
                 watermark: false,
                 grammar: String::new(),
                 grammar_type: ProtoGrammarType::None as i32,
+                states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -372,6 +372,7 @@ mod tests {
             watermark: false,
             grammar: String::new(),
             grammar_type: ProtoGrammarType::None as i32,
+            states_to_token_maps: None,
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,
@@ -7,7 +7,8 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
+    StoppingCriteriaParameters,
 };
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -64,19 +65,36 @@ impl Validation {
                 });
             }

-            // TODO: start more than one grammar compilation worker
-            let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+            // Create round robin channel
+            let (grammar_sender, grammar_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut grammar_senders = Vec::with_capacity(workers);

-            // create single grammar compilation workers
-            tokio::task::spawn_blocking(move || {
-                grammar_compilation_worker(tokenizer, grammar_receiver).map_err(|e| {
-                    tracing::error!("Error in grammar compilation worker: {:?}", e);
-                    e
-                })
-            });
+            // Create workers
+            for _ in 0..workers {
+                let tokenizer_clone = tokenizer.clone();
+                let (grammar_sender, grammar_receiver) = mpsc::unbounded_channel();
+                grammar_senders.push(grammar_sender);
+
+                // Spawn worker
+                tokio::task::spawn_blocking(move || {
+                    grammar_compilation_worker(tokenizer_clone, grammar_receiver).map_err(|e| {
+                        tracing::error!("Error in grammar compilation worker: {:?}", e);
+                        e
+                    })
+                });
+            }

             // Create tokenization round robin task
-            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+            tokio::spawn(round_robin_task::<TokenizerRequest>(
+                validation_round_robin_receiver,
+                senders,
+            ));
+
+            // Create grammar compilation round robin task
+            tokio::spawn(round_robin_task::<GrammarCompilationRequest>(
+                grammar_round_robin_receiver,
+                grammar_senders,
+            ));

             (Some(validation_sender), Some(grammar_sender))
         } else {
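For reference, a sketch of the per-worker loop on the other end of these channels. grammar_compilation_worker itself is not part of this diff, so the body below is an assumption built from the GrammarCompilationRequest tuple type defined at the end of the diff:

// Assumed shape, not from this commit: drain requests from this worker's
// dedicated channel and reply on the oneshot. compile_grammar is CPU-bound
// (it holds the Python GIL), hence spawn_blocking above and blocking_recv here.
fn grammar_compilation_worker(
    tokenizer: Tokenizer,
    mut receiver: mpsc::UnboundedReceiver<GrammarCompilationRequest>,
) -> Result<(), ValidationError> {
    while let Some((inputs, response_tx, parent_span)) = receiver.blocking_recv() {
        parent_span.in_scope(|| {
            // Ignore send errors: the requester may have dropped the receiver.
            let _ = response_tx.send(compile_grammar(inputs, &tokenizer));
        })
    }
    Ok(())
}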
@@ -121,7 +139,10 @@ impl Validation {
     }

     #[instrument(skip(self, inputs))]
-    pub async fn compile_grammar(&self, inputs: String) -> Result<String, ValidationError> {
+    pub async fn compile_grammar(
+        &self,
+        inputs: String,
+    ) -> Result<(String, StateTokenMaps), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.grammar_compilation_sender {
             // Create response channel
@@ -138,7 +159,7 @@ impl Validation {
             return Ok(encoding);
         }

-        Ok(inputs)
+        Ok((String::new(), BTreeMap::new()))
     }

     #[instrument(skip(self, inputs))]
@@ -347,7 +368,7 @@ impl Validation {
         // compiler and use that to build the FSM here.

         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let (grammar, grammar_type, states_to_token_maps) = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
|
|||||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
||||||
|
|
||||||
// NOTE: this is the first step to compile the grammar
|
// NOTE: this is the first step to compile the grammar
|
||||||
let regex_compiled_grammar = self
|
let (regex_compiled_grammar, _states_to_token_maps) = self
|
||||||
.compile_grammar(serde_json::to_string(&json).unwrap())
|
.compile_grammar(serde_json::to_string(&json).unwrap())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
||||||
(regex_compiled_grammar, ProtoGrammarType::Regex.into())
|
|
||||||
|
let stm = StatesToTokenMaps {
|
||||||
|
start_states: vec![],
|
||||||
|
tokens: vec![],
|
||||||
|
end_states: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
(
|
||||||
|
regex_compiled_grammar,
|
||||||
|
ProtoGrammarType::Regex.into(),
|
||||||
|
Some(stm),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
|
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => (String::new(), ProtoGrammarType::None.into()),
|
None => (String::new(), ProtoGrammarType::None.into(), None),
|
||||||
};
|
};
|
||||||
|
|
||||||
let parameters = NextTokenChooserParameters {
|
let parameters = NextTokenChooserParameters {
|
||||||
@@ -395,6 +427,7 @@ impl Validation {
         watermark,
         grammar,
         grammar_type,
+        states_to_token_maps,
     };
     let stopping_parameters = StoppingCriteriaParameters {
         max_new_tokens,
@@ -431,9 +464,9 @@ impl Validation {
 }

 /// Round robin tokenization task
-async fn round_robin_task(
-    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
-    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+async fn round_robin_task<T>(
+    mut receiver: mpsc::UnboundedReceiver<T>,
+    senders: Vec<mpsc::UnboundedSender<T>>,
 ) {
     loop {
         for sender in &senders {
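The diff only shows the signature going generic so the same dispatcher can serve both TokenizerRequest and GrammarCompilationRequest channels. A sketch of the complete generic function, assuming the loop body that the hunk's trailing context lines suggest is unchanged:

async fn round_robin_task<T>(
    mut receiver: mpsc::UnboundedReceiver<T>,
    senders: Vec<mpsc::UnboundedSender<T>>,
) {
    loop {
        // Hand one request to each worker in turn; exit once the shared
        // receiver is closed and drained.
        for sender in &senders {
            match receiver.recv().await {
                None => return,
                Some(request) => sender.send(request).unwrap(),
            };
        }
    }
}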
@@ -507,10 +540,15 @@ fn prepare_input(
     Ok((encoding, inputs))
 }

+type StateTokenMaps = BTreeMap<u32, BTreeMap<u32, u32>>;
+
 /// Compile a grammar
-fn compile_grammar(inputs: String, tokenizer: &Tokenizer) -> Result<String, ValidationError> {
+fn compile_grammar(
+    inputs: String,
+    tokenizer: &Tokenizer,
+) -> Result<(String, StateTokenMaps), ValidationError> {
     let start_time = std::time::Instant::now();
-    let schema = Python::with_gil(|py| -> PyResult<String> {
+    let (schema, states_to_token_maps) = Python::with_gil(|py| -> PyResult<(_, _)> {
         let fun: Py<PyAny> = PyModule::from_code(
             py,
             r#"
@@ -555,7 +593,7 @@ def adapt_tokenizer(vocab, special_tokens):
     print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     return tokenizer

-def compile_regex_grammar(inputs, vocab):
+def compile_regex_grammar(inputs, vocab, special_tokens):
     start_time = time.time()
     print("🔥 starting compile_regex_grammar", inputs)
     schema = build_regex_from_schema(inputs)
@@ -576,22 +614,23 @@ def convert_grammar_to_regex(inputs):
             "",
             "",
         )?
-        // .getattr("compile_regex_grammar")?
-        .getattr("convert_grammar_to_regex")?
-        .into();
+        .into_py(py);
+
+        let convert_grammar_to_regex = fun.getattr(py, "convert_grammar_to_regex")?;
+        let compile_regex_grammar = fun.getattr(py, "compile_regex_grammar")?;

         let args: &pyo3::types::PyDict = tokenizer.get_vocab(true).into_py_dict(py);
         let special_tokens: Vec<String> = vec![];

-        let regex_fsm = fun.call(py, (inputs.clone(),), None)?;
+        let regex_fsm = convert_grammar_to_regex.call(py, (inputs.clone(),), None)?;

-        if false {
-            let regex_fsm = fun.call(py, (inputs.clone(), args, special_tokens), None)?;
-            let regex_fsm_ref = regex_fsm.into_ref(py);
+        let compiled_grammar =
+            compile_regex_grammar.call(py, (inputs.clone(), args, special_tokens), None)?;
+        let compiled_grammar_ref = compiled_grammar.into_ref(py);

-            let states_to_token_maps = regex_fsm_ref
-                .getattr("states_to_token_maps")?
-                .extract::<BTreeMap<u32, BTreeMap<u32, u32>>>()?;
+        let states_to_token_maps = compiled_grammar_ref
+            .getattr("states_to_token_maps")?
+            .extract::<StateTokenMaps>()?;

         println!("🔥 elapsed: {:?}", start_time.elapsed());
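The restructuring above replaces a single chained getattr with named handles, so two functions from the same inline Python module can each be called. A standalone sketch of that PyO3 pattern with toy functions (not from this commit; assumes the pyo3 crate with the auto-initialize feature enabled):

use pyo3::prelude::*;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Build a module object from inline Python source once.
        let module: Py<PyAny> = PyModule::from_code(
            py,
            r#"
def double(x):
    return 2 * x

def halve(x):
    return x / 2
"#,
            "",
            "",
        )?
        .into_py(py);

        // Look up each function by name, then call and extract results.
        let double = module.getattr(py, "double")?;
        let halve = module.getattr(py, "halve")?;

        let doubled: i64 = double.call1(py, (21,))?.extract(py)?;
        let halved: f64 = halve.call1(py, (21,))?.extract(py)?;
        println!("doubled={doubled} halved={halved}");
        Ok(())
    })
}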
@@ -601,23 +640,22 @@ def convert_grammar_to_regex(inputs):
             "🔥 states_to_token_maps size: {:.2}MB",
             states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
         );
-        }

         let result = regex_fsm.into_ref(py).extract().unwrap();

         println!("result: {:?}", result);

-        Ok(result)
+        Ok((result, states_to_token_maps))
     })
     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
     let elapsed = start_time.elapsed();
     println!("🔥 elapsed: {:?}", elapsed);
-    Ok(schema)
+    Ok((schema, states_to_token_maps))
 }

 type GrammarCompilationRequest = (
     String,
-    oneshot::Sender<Result<String, ValidationError>>,
+    oneshot::Sender<Result<(String, StateTokenMaps), ValidationError>>,
     Span,
 );