Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-25 02:32:14 +00:00
feat: remove uncompiled grammar and improve logit processor logic
parent c52a0f679e
commit 1f7be736d2
@@ -250,5 +250,7 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base

+ENV LD_LIBRARY_PATH=/opt/conda/lib/:$LD_LIBRARY_PATH
+
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -8,7 +8,7 @@ use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
 use std::io;
-use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
+use text_generation_client::{NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;
@@ -45,8 +45,6 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
-        grammar: String::new(),
-        grammar_type: GrammarType::None as i32,
         states_to_token_maps: None,
     };

@@ -51,12 +51,6 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}

-enum GrammarType {
-    GRAMMAR_TYPE_NONE = 0;
-    GRAMMAR_TYPE_JSON = 1;
-    GRAMMAR_TYPE_REGEX = 2;
-}
-
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
@@ -76,10 +70,6 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
-    /// grammar (applied if not empty)
-    string grammar = 10;
-    /// grammar type
-    GrammarType grammar_type = 11;
     /// states to token maps
     StatesToTokenMaps states_to_token_maps = 12;
 }
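Note: the StatesToTokenMaps message kept above is now the only grammar payload left on the wire. A minimal sketch of the flattened encoding it carries, assuming the three parallel arrays (start_states, tokens, end_states) that appear later in this diff; the helper names are illustrative, not part of the codebase.

from typing import Dict, List, Tuple

def flatten(transitions: Dict[int, Dict[int, int]]) -> Tuple[List[int], List[int], List[int]]:
    """Encode {state: {token_id: next_state}} as three parallel arrays,
    one FSM transition (start_states[i], tokens[i]) -> end_states[i] per index."""
    start_states, tokens, end_states = [], [], []
    for state, token_map in transitions.items():
        for token_id, next_state in token_map.items():
            start_states.append(state)
            tokens.append(token_id)
            end_states.append(next_state)
    return start_states, tokens, end_states

def unflatten(start_states, tokens, end_states) -> Dict[int, Dict[int, int]]:
    """Rebuild the nested map; mirrors the loop added to logits_process.py below."""
    transitions: Dict[int, Dict[int, int]] = {}
    for s, t, e in zip(start_states, tokens, end_states):
        transitions.setdefault(s, {})[t] = e
    return transitions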
@@ -135,8 +135,6 @@ impl Client {
                 repetition_penalty: 1.2,
                 frequency_penalty: 0.1,
                 watermark: true,
-                grammar: String::new(),
-                grammar_type: GrammarType::None as i32,
                 states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -1,6 +1,5 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
-use text_generation_client::GrammarType as ProtoGrammarType;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
@@ -46,8 +45,6 @@ impl Health {
             repetition_penalty: 1.0,
             frequency_penalty: 0.0,
             watermark: false,
-            grammar: String::new(),
-            grammar_type: ProtoGrammarType::None as i32,
             states_to_token_maps: None,
         }),
         stopping_parameters: Some(StoppingCriteriaParameters {
@@ -343,9 +343,7 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{
-        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-    };
+    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
     use tracing::info_span;

     fn default_entry() -> (
@@ -370,8 +368,6 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                grammar_type: ProtoGrammarType::None as i32,
                 states_to_token_maps: None,
             },
             stopping_parameters: StoppingCriteriaParameters {
@@ -7,8 +7,7 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
-    StoppingCriteriaParameters,
+    NextTokenChooserParameters, StatesToTokenMaps, StoppingCriteriaParameters,
 };
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -368,7 +367,7 @@ impl Validation {
         // compiler and use that to build the FSM here.

         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type, states_to_token_maps) = match grammar {
+        let states_to_token_maps = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
@@ -392,7 +391,7 @@ impl Validation {
                     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

                 // NOTE: this is the first step to compile the grammar
-                let (regex_compiled_grammar, _states_to_token_maps) = self
+                let (_regex_compiled_grammar, _states_to_token_maps) = self
                     .compile_grammar(serde_json::to_string(&json).unwrap())
                     .await
                     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
@@ -416,16 +415,12 @@ impl Validation {
                             end_states,
                         };

-                        (
-                            regex_compiled_grammar,
-                            ProtoGrammarType::Regex.into(),
-                            Some(stm),
-                        )
+                        Some(stm)
                     }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
+                    GrammarType::Regex(_regex) => None,
                 }
             }
-            None => (String::new(), ProtoGrammarType::None.into(), None),
+            None => None,
         };

         let parameters = NextTokenChooserParameters {
@@ -438,8 +433,6 @@ impl Validation {
             do_sample,
             seed,
             watermark,
-            grammar,
-            grammar_type,
             states_to_token_maps,
         };
         let stopping_parameters = StoppingCriteriaParameters {
|
||||
r#"
|
||||
from outlines.fsm.fsm import RegexFSM
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
import time
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
class Tokenizer:
|
||||
@@ -589,13 +581,10 @@ class Tokenizer:
         return " ".join(tokens)

 def adapt_tokenizer(vocab, special_tokens):
-    start_time = time.time()
     tokenizer = Tokenizer(vocab, special_tokens)

     def convert_token_to_string(token: str) -> str:
-
         string = tokenizer.convert_tokens_to_string([token])
-
         # A hack to handle missing spaces to HF's Llama tokenizers
         if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
             return " " + string
@@ -603,26 +592,16 @@ def adapt_tokenizer(vocab, special_tokens):
         return string

     tokenizer.convert_token_to_string = convert_token_to_string
-    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     return tokenizer

 def compile_regex_grammar(inputs, vocab, special_tokens):
-    start_time = time.time()
-    print("🔥 starting compile_regex_grammar", inputs)
     schema = build_regex_from_schema(inputs)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
     tokenizer = adapt_tokenizer(vocab, special_tokens)
-    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     fsm = RegexFSM(schema, tokenizer)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
     return fsm

 def convert_grammar_to_regex(inputs):
-    start_time = time.time()
-    print("🔥 starting convert_grammar_to_regex", inputs)
-    schema = build_regex_from_schema(inputs)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
-    return schema
+    return build_regex_from_schema(inputs)
 "#,
         "",
         "",
@@ -645,19 +624,7 @@ def convert_grammar_to_regex(inputs):
                 .getattr("states_to_token_maps")?
                 .extract::<StateTokenMaps>()?;

-            println!("🔥 elapsed: {:?}", start_time.elapsed());
-
-            // size of serialized states_to_token_maps
-            let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
-            println!(
-                "🔥 states_to_token_maps size: {:.2}MB",
-                states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
-            );
-
             let result = regex_fsm.into_ref(py).extract().unwrap();
-
-            println!("result: {:?}", result);
-
             Ok((result, states_to_token_maps))
         })
         .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
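With the timing and debug prints gone, the embedded Python reduces to the two functions kept above. A rough standalone sketch of what they compute, assuming outlines is installed; the schema literal is only an example input.

from outlines.fsm.json_schema import build_regex_from_schema

# Example input only; any JSON Schema accepted by outlines works here.
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'

# convert_grammar_to_regex: JSON Schema -> regular expression.
regex = build_regex_from_schema(schema)

# compile_regex_grammar additionally walks that regex against the adapted
# tokenizer vocabulary: RegexFSM(regex, tokenizer) yields the
# states_to_token_maps that the Rust side extracts and serializes.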
@@ -475,22 +475,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM

-    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
+    def __init__(self, tokenizer, device, states_to_token_maps):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)

-        # TODO: use the precompiled grammar here
-        self.states_to_token_maps = states_to_token_maps
-        precompiled_grammar = RegexFSM.precompiled(
-            states_to_token_maps=states_to_token_maps,
-            empty_token_ids=None,
-            vocabulary=None,
-            eos_token_id=None,
-        )
+        start_states = states_to_token_maps.start_states
+        tokens = states_to_token_maps.tokens
+        end_states = states_to_token_maps.end_states

-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            grammar_type, grammar, self.tokenizer
-        )
+        _states_to_token_maps = {}
+        for i in range(len(start_states)):
+            if start_states[i] in _states_to_token_maps:
+                _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+            else:
+                _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+
+        fsm = object.__new__(RegexFSM)
+        fsm.states_to_token_maps = _states_to_token_maps
+        fsm.empty_token_ids = None
+        fsm.vocabulary = list(tokenizer.get_vocab().values())
+        fsm.eos_token_id = tokenizer.eos_token_id
+        self.fsm = fsm

     def __call__(
         self,
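The __call__ body is outside this hunk. For context, a minimal sketch of the usual FSM-guided masking step, assuming a 1-D logits tensor and an fsm hydrated as above; real code would also allow EOS at final states.

import math
import torch

def mask_logits(fsm, fsm_state: int, logits: torch.Tensor) -> torch.Tensor:
    # Only token ids with a transition out of the current state survive.
    allowed = list(fsm.states_to_token_maps.get(fsm_state, {}).keys())
    mask = torch.full_like(logits, -math.inf)
    mask[allowed] = 0
    return logits + mask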
@@ -560,17 +564,14 @@ class GrammarLogitProcessor(LogitsProcessor):


 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(
-        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
-    ):
+    def __init__(self, tokenizer, device, states_to_token_maps):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
-
-        for grammar, grammar_type in zip(grammars, grammar_types):
-            start_states = states_to_token_maps[0].start_states
-            tokens = states_to_token_maps[0].tokens
-            end_states = states_to_token_maps[0].end_states
+        for states_to_token_map in states_to_token_maps:
+            start_states = states_to_token_map.start_states
+            tokens = states_to_token_map.tokens
+            end_states = states_to_token_map.end_states

             _states_to_token_maps = {}
             for i in range(len(start_states)):
@@ -579,18 +580,11 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
                 else:
                     _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}

-            # TODO: cleanup how precompiled grammars are handled
-            precompiled_grammar = RegexFSM.precompiled(
-                states_to_token_maps=_states_to_token_maps,
-                empty_token_ids=None,
-                vocabulary=list(tokenizer.get_vocab().values()),
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-            # fsm = GrammarLogitProcessor._cached_compile_fsm(
-            #     grammar_type, grammar, self.tokenizer
-            # )
-
-            fsm = precompiled_grammar
+            fsm = object.__new__(RegexFSM)
+            fsm.states_to_token_maps = _states_to_token_maps
+            fsm.empty_token_ids = None
+            fsm.vocabulary = list(tokenizer.get_vocab().values())
+            fsm.eos_token_id = tokenizer.eos_token_id
             self.fsms.append(fsm)

     def __call__(
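Same hydration trick as in GrammarLogitProcessor above: object.__new__ allocates a RegexFSM without running __init__, which would otherwise recompile the regex the router has already compiled. The pattern in isolation, with a hypothetical class standing in for RegexFSM:

class ExpensiveFSM:
    def __init__(self, regex: str):
        ...  # imagine a costly regex -> FSM compilation here

# Allocate without running __init__, then hydrate fields directly from
# precomputed state, as the diff does for RegexFSM.
fsm = object.__new__(ExpensiveFSM)
fsm.states_to_token_maps = {0: {42: 1}}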
@@ -36,8 +36,6 @@ class NextTokenChooser:
         seed: int = 0,
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        grammar: str = "",
-        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
         states_to_token_maps: List[List[int]] = None,
     ):
@@ -55,10 +53,8 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(
-                tokenizer, device, grammar, grammar_type, states_to_token_maps
-            )
-            if grammar != ""
+            GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
+            if states_to_token_maps
             else None
         )
         self.tokenizer = tokenizer
@@ -80,7 +76,6 @@ class NextTokenChooser:

         self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
-        self.grammar = grammar
         self.states_to_token_maps = states_to_token_maps

     def __call__(self, input_ids, scores):
@@ -128,8 +123,6 @@ class NextTokenChooser:
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
-            grammar=pb.grammar,
-            grammar_type=pb.grammar_type,
             states_to_token_maps=pb.states_to_token_maps,
         )

@@ -236,8 +229,6 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
         tokenizer: PreTrainedTokenizerBase,
-        grammars: List[str],
-        grammar_types: List[int],
         fsm_grammar_states: List[int],
         states_to_token_maps: List[List[List[int]]],
     ):
@@ -272,10 +263,8 @@ class HeterogeneousNextTokenChooser:
         )

         self.grammar_processor = (
-            HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types, states_to_token_maps
-            )
-            if any([grammar != "" for grammar in grammars])
+            HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
+            if any(states_to_token_maps)
             else None
         )

@@ -312,8 +301,6 @@ class HeterogeneousNextTokenChooser:
         self.device = device
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammars
-        self.grammar_types = grammar_types
         self.states_to_token_maps = states_to_token_maps

     def __call__(
@@ -447,17 +434,11 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]

-        new_grammars = []
         new_fsm_grammar_states = []
-        new_grammar_types = []
         for i in indices:
-            new_grammars.append(self.grammars[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
-            new_grammar_types.append(self.grammar_types[i])

-        self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
-        self.grammar_types = new_grammar_types

         if any(self.do_sample):
             self.choice.filter(indices)
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
tokenizer=tokenizer,
|
||||
grammars=[pb_.grammar for pb_ in pb],
|
||||
grammar_types=[pb_.grammar_type for pb_ in pb],
|
||||
fsm_grammar_states=(
|
||||
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
|
||||
),
|
||||
|