feat: remove uncompiled grammar and improve logit processor logic

drbh 2024-03-08 03:34:19 +00:00
parent c52a0f679e
commit 1f7be736d2
10 changed files with 44 additions and 123 deletions

View File

@@ -250,5 +250,7 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base
ENV LD_LIBRARY_PATH=/opt/conda/lib/:$LD_LIBRARY_PATH
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
use text_generation_client::{NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
@@ -45,8 +45,6 @@ pub async fn run(
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
states_to_token_maps: None,
};

View File

@@ -51,12 +51,6 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
@@ -76,10 +70,6 @@ message NextTokenChooserParameters {
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
/// states to token maps
StatesToTokenMaps states_to_token_maps = 12;
}
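
The `grammar` string and the `GrammarType` enum disappear from the wire format; the only grammar-related field left is the pre-compiled `StatesToTokenMaps` at tag 12. Judging from the Rust and Python sides of this diff, that message flattens the FSM transition table into three parallel arrays (`start_states`, `tokens`, `end_states`). A minimal round-trip sketch, with a dataclass standing in for the generated proto class:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class StatesToTokenMaps:
    # Hypothetical stand-in for the generated proto message: transition i is
    # start_states[i] --tokens[i]--> end_states[i].
    start_states: List[int] = field(default_factory=list)
    tokens: List[int] = field(default_factory=list)
    end_states: List[int] = field(default_factory=list)

def flatten(table: Dict[int, Dict[int, int]]) -> StatesToTokenMaps:
    stm = StatesToTokenMaps()
    for start, transitions in table.items():
        for token, end in transitions.items():
            stm.start_states.append(start)
            stm.tokens.append(token)
            stm.end_states.append(end)
    return stm

def unflatten(stm: StatesToTokenMaps) -> Dict[int, Dict[int, int]]:
    # Mirrors the rebuild loop in GrammarLogitProcessor.__init__ further down.
    table: Dict[int, Dict[int, int]] = {}
    for start, token, end in zip(stm.start_states, stm.tokens, stm.end_states):
        table.setdefault(start, {})[token] = end
    return table

assert unflatten(flatten({0: {5: 1, 7: 2}, 1: {5: 1}})) == {0: {5: 1, 7: 2}, 1: {5: 1}}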

View File

@@ -135,8 +135,6 @@ impl Client {
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
states_to_token_maps: None,
}),
stopping_parameters: Some(StoppingCriteriaParameters {

View File

@@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@@ -1,6 +1,5 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
@@ -46,8 +45,6 @@ impl Health {
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
states_to_token_maps: None,
}),
stopping_parameters: Some(StoppingCriteriaParameters {

View File

@@ -343,9 +343,7 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
use super::*;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tracing::info_span;
fn default_entry() -> (
@@ -370,8 +368,6 @@ mod tests {
repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
states_to_token_maps: None,
},
stopping_parameters: StoppingCriteriaParameters {

View File

@@ -7,8 +7,7 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
StoppingCriteriaParameters,
NextTokenChooserParameters, StatesToTokenMaps, StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@@ -368,7 +367,7 @@ impl Validation {
// compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type, states_to_token_maps) = match grammar {
let states_to_token_maps = match grammar {
Some(grammar) => {
// Ensure that grammar is not set if it's not supported
if self.disable_grammar_support {
@@ -392,7 +391,7 @@ impl Validation {
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
// NOTE: this is the first step to compile the grammar
let (regex_compiled_grammar, _states_to_token_maps) = self
let (_regex_compiled_grammar, _states_to_token_maps) = self
.compile_grammar(serde_json::to_string(&json).unwrap())
.await
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
@@ -416,16 +415,12 @@ impl Validation {
end_states,
};
(
regex_compiled_grammar,
ProtoGrammarType::Regex.into(),
Some(stm),
)
Some(stm)
}
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
GrammarType::Regex(_regex) => None,
}
}
None => (String::new(), ProtoGrammarType::None.into(), None),
None => None,
};
let parameters = NextTokenChooserParameters {
@@ -438,8 +433,6 @@ impl Validation {
do_sample,
seed,
watermark,
grammar,
grammar_type,
states_to_token_maps,
};
let stopping_parameters = StoppingCriteriaParameters {
@@ -567,7 +560,6 @@ fn compile_grammar(
r#"
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
import time
from transformers.file_utils import SPIECE_UNDERLINE
class Tokenizer:
@@ -589,13 +581,10 @@ class Tokenizer:
return " ".join(tokens)
def adapt_tokenizer(vocab, special_tokens):
start_time = time.time()
tokenizer = Tokenizer(vocab, special_tokens)
def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
@@ -603,26 +592,16 @@ def adapt_tokenizer(vocab, special_tokens):
return string
tokenizer.convert_token_to_string = convert_token_to_string
print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
return tokenizer
def compile_regex_grammar(inputs, vocab, special_tokens):
start_time = time.time()
print("🔥 starting compile_regex_grammar", inputs)
schema = build_regex_from_schema(inputs)
print(f"Compiled grammar in {time.time() - start_time:.2f}s")
tokenizer = adapt_tokenizer(vocab, special_tokens)
print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
fsm = RegexFSM(schema, tokenizer)
print(f"Compiled grammar in {time.time() - start_time:.2f}s")
return fsm
def convert_grammar_to_regex(inputs):
start_time = time.time()
print("🔥 starting convert_grammar_to_regex", inputs)
schema = build_regex_from_schema(inputs)
print(f"Compiled grammar in {time.time() - start_time:.2f}s")
return schema
return build_regex_from_schema(inputs)
"#,
"",
"",
@@ -645,19 +624,7 @@ def convert_grammar_to_regex(inputs):
.getattr("states_to_token_maps")?
.extract::<StateTokenMaps>()?;
println!("🔥 elapsed: {:?}", start_time.elapsed());
// size of serialized states_to_token_maps
let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
println!(
"🔥 states_to_token_maps size: {:.2}MB",
states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
);
let result = regex_fsm.into_ref(py).extract().unwrap();
println!("result: {:?}", result);
Ok((result, states_to_token_maps))
})
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
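
Stripped of the timing and debug prints, the embedded Python that the router runs through pyo3 reduces to a three-step pipeline: JSON schema to regex, tokenizer adaptation, then FSM construction. A sketch of that pipeline as standalone Python, assuming `outlines` is installed; `adapt_tokenizer` is the helper defined in the embedded block above:

# Sketch of the grammar-compilation step executed via pyo3 above; the Rust
# caller keeps only fsm.states_to_token_maps and drops the FSM object itself.
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema

def compile_regex_grammar(schema_json, vocab, special_tokens):
    regex = build_regex_from_schema(schema_json)        # JSON schema -> regex string
    tokenizer = adapt_tokenizer(vocab, special_tokens)  # helper from the block above
    return RegexFSM(regex, tokenizer)                   # regex -> token-level FSM

Compiling once here, at validation time, is what lets the shards skip compilation entirely (see the logit processors below).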

View File

@@ -475,22 +475,26 @@ class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int]
fsm: RegexFSM
def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
def __init__(self, tokenizer, device, states_to_token_maps):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
# TODO: use the precompiled grammar here
self.states_to_token_maps = states_to_token_maps
precompiled_grammar = RegexFSM.precompiled(
states_to_token_maps=states_to_token_maps,
empty_token_ids=None,
vocabulary=None,
eos_token_id=None,
)
start_states = states_to_token_maps.start_states
tokens = states_to_token_maps.tokens
end_states = states_to_token_maps.end_states
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer
)
_states_to_token_maps = {}
for i in range(len(start_states)):
if start_states[i] in _states_to_token_maps:
_states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
else:
_states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
fsm = object.__new__(RegexFSM)
fsm.states_to_token_maps = _states_to_token_maps
fsm.empty_token_ids = None
fsm.vocabulary = list(tokenizer.get_vocab().values())
fsm.eos_token_id = tokenizer.eos_token_id
self.fsm = fsm
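
Both this constructor and the heterogeneous variant below (which repeats the rebuild once per request in the batch) now rehydrate the FSM from the shipped table instead of compiling anything. A condensed sketch of the pattern, with a hypothetical helper name; `object.__new__` deliberately bypasses `RegexFSM.__init__`, so no regex is compiled on the shard:

from outlines.fsm.fsm import RegexFSM

def fsm_from_states_to_token_maps(stm, tokenizer):
    # Rebuild {start_state: {token_id: end_state}} from the parallel arrays.
    table = {}
    for start, token, end in zip(stm.start_states, stm.tokens, stm.end_states):
        table.setdefault(start, {})[token] = end
    # Allocate without running __init__ (which would recompile the regex).
    fsm = object.__new__(RegexFSM)
    fsm.states_to_token_maps = table
    fsm.empty_token_ids = None
    fsm.vocabulary = list(tokenizer.get_vocab().values())
    fsm.eos_token_id = tokenizer.eos_token_id
    return fsm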
def __call__(
self,
@@ -560,17 +564,14 @@ class GrammarLogitProcessor(LogitsProcessor):
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(
self, tokenizer, device, grammars, grammar_types, states_to_token_maps
):
def __init__(self, tokenizer, device, states_to_token_maps):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = []
for grammar, grammar_type in zip(grammars, grammar_types):
start_states = states_to_token_maps[0].start_states
tokens = states_to_token_maps[0].tokens
end_states = states_to_token_maps[0].end_states
for states_to_token_map in states_to_token_maps:
start_states = states_to_token_map.start_states
tokens = states_to_token_map.tokens
end_states = states_to_token_map.end_states
_states_to_token_maps = {}
for i in range(len(start_states)):
@@ -579,18 +580,11 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
else:
_states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
# TODO: cleanup how precompiled grammars are handled
precompiled_grammar = RegexFSM.precompiled(
states_to_token_maps=_states_to_token_maps,
empty_token_ids=None,
vocabulary=list(tokenizer.get_vocab().values()),
eos_token_id=self.tokenizer.eos_token_id,
)
# fsm = GrammarLogitProcessor._cached_compile_fsm(
# grammar_type, grammar, self.tokenizer
# )
fsm = precompiled_grammar
fsm = object.__new__(RegexFSM)
fsm.states_to_token_maps = _states_to_token_maps
fsm.empty_token_ids = None
fsm.vocabulary = list(tokenizer.get_vocab().values())
fsm.eos_token_id = tokenizer.eos_token_id
self.fsms.append(fsm)
def __call__(

View File

@@ -36,8 +36,6 @@ class NextTokenChooser:
seed: int = 0,
device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0,
states_to_token_maps: List[List[int]] = None,
):
@@ -55,10 +53,8 @@ class NextTokenChooser:
else None
)
self.grammar_processor = (
GrammarLogitProcessor(
tokenizer, device, grammar, grammar_type, states_to_token_maps
)
if grammar != ""
GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
if states_to_token_maps
else None
)
self.tokenizer = tokenizer
@@ -80,7 +76,6 @@ class NextTokenChooser:
self.choice = Sampling(seed, device) if sampling else Greedy()
self.fsm_grammar_state = fsm_grammar_state
self.grammar = grammar
self.states_to_token_maps = states_to_token_maps
def __call__(self, input_ids, scores):
@@ -128,8 +123,6 @@ class NextTokenChooser:
seed=pb.seed,
device=device,
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
states_to_token_maps=pb.states_to_token_maps,
)
@@ -236,8 +229,6 @@ class HeterogeneousNextTokenChooser:
do_sample: List[bool],
seeds: List[int],
tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states: List[int],
states_to_token_maps: List[List[List[int]]],
):
@@ -272,10 +263,8 @@ class HeterogeneousNextTokenChooser:
)
self.grammar_processor = (
HeterogeneousGrammarLogitProcessor(
tokenizer, device, grammars, grammar_types, states_to_token_maps
)
if any([grammar != "" for grammar in grammars])
HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
if any(states_to_token_maps)
else None
)
@@ -312,8 +301,6 @@ class HeterogeneousNextTokenChooser:
self.device = device
self.tokenizer = tokenizer
self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
self.states_to_token_maps = states_to_token_maps
def __call__(
@@ -447,17 +434,11 @@ class HeterogeneousNextTokenChooser:
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = []
new_grammar_types = []
for i in indices:
new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
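
With `grammars` and `grammar_types` gone, `filter` only has to carry `fsm_grammar_states` forward; a sketch of the surviving bookkeeping as a comprehension (function name illustrative):

def filter_fsm_states(fsm_grammar_states, indices):
    # Keep only the FSM states of the requests that remain in the batch.
    return [fsm_grammar_states[i] for i in indices]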
if any(self.do_sample):
self.choice.filter(indices)
@@ -488,8 +469,6 @@ class HeterogeneousNextTokenChooser:
device=device,
dtype=dtype,
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
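
After this change, the presence of `states_to_token_maps` is the sole signal that a request is grammar-constrained. A hedged sketch of the gating both choosers apply (illustrative helper; `GrammarLogitProcessor` as in this diff):

def build_grammar_processor(tokenizer, device, states_to_token_maps):
    # No grammar string or GrammarType anymore: an absent state table simply
    # means the request has no grammar constraint.
    if not states_to_token_maps:
        return None
    return GrammarLogitProcessor(tokenizer, device, states_to_token_maps)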