feat: remove uncompiled grammar and improve logit processor logic

This commit is contained in:
drbh 2024-03-08 03:34:19 +00:00
parent c52a0f679e
commit 1f7be736d2
10 changed files with 44 additions and 123 deletions

View File

@ -250,5 +250,7 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image # Final image
FROM base FROM base
ENV LD_LIBRARY_PATH=/opt/conda/lib/:$LD_LIBRARY_PATH
ENTRYPOINT ["text-generation-launcher"] ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"] CMD ["--json-output"]

View File

@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event; use crate::event::Event;
use crossterm::ExecutableCommand; use crossterm::ExecutableCommand;
use std::io; use std::io;
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; use text_generation_client::{NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend; use tui::backend::CrosstermBackend;
@ -45,8 +45,6 @@ pub async fn run(
repetition_penalty: repetition_penalty.unwrap_or(1.0), repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0), frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark, watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
states_to_token_maps: None, states_to_token_maps: None,
}; };

View File

@ -51,12 +51,6 @@ message ClearCacheRequest {
/// Empty response /// Empty response
message ClearCacheResponse {} message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters { message NextTokenChooserParameters {
/// exponential scaling output probability distribution /// exponential scaling output probability distribution
float temperature = 1; float temperature = 1;
@ -76,10 +70,6 @@ message NextTokenChooserParameters {
float frequency_penalty = 9; float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models" /// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8; bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
/// states to token maps /// states to token maps
StatesToTokenMaps states_to_token_maps = 12; StatesToTokenMaps states_to_token_maps = 12;
} }

View File

@ -135,8 +135,6 @@ impl Client {
repetition_penalty: 1.2, repetition_penalty: 1.2,
frequency_penalty: 0.1, frequency_penalty: 0.1,
watermark: true, watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
states_to_token_maps: None, states_to_token_maps: None,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {

View File

@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse; pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{ pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;

View File

@ -1,6 +1,5 @@
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{ use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
}; };
@ -46,8 +45,6 @@ impl Health {
repetition_penalty: 1.0, repetition_penalty: 1.0,
frequency_penalty: 0.0, frequency_penalty: 0.0,
watermark: false, watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
states_to_token_maps: None, states_to_token_maps: None,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {

View File

@ -343,9 +343,7 @@ enum QueueCommand {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use text_generation_client::{ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use tracing::info_span; use tracing::info_span;
fn default_entry() -> ( fn default_entry() -> (
@ -370,8 +368,6 @@ mod tests {
repetition_penalty: 0.0, repetition_penalty: 0.0,
frequency_penalty: 0.0, frequency_penalty: 0.0,
watermark: false, watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
states_to_token_maps: None, states_to_token_maps: None,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: StoppingCriteriaParameters {

View File

@ -7,8 +7,7 @@ use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use text_generation_client::{ use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps, NextTokenChooserParameters, StatesToTokenMaps, StoppingCriteriaParameters,
StoppingCriteriaParameters,
}; };
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
@ -368,7 +367,7 @@ impl Validation {
// compiler and use that to build the FSM here. // compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message // Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type, states_to_token_maps) = match grammar { let states_to_token_maps = match grammar {
Some(grammar) => { Some(grammar) => {
// Ensure that grammar is not set if it's not supported // Ensure that grammar is not set if it's not supported
if self.disable_grammar_support { if self.disable_grammar_support {
@ -392,7 +391,7 @@ impl Validation {
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
// NOTE: this is the first step to compile the grammar // NOTE: this is the first step to compile the grammar
let (regex_compiled_grammar, _states_to_token_maps) = self let (_regex_compiled_grammar, _states_to_token_maps) = self
.compile_grammar(serde_json::to_string(&json).unwrap()) .compile_grammar(serde_json::to_string(&json).unwrap())
.await .await
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
@ -416,16 +415,12 @@ impl Validation {
end_states, end_states,
}; };
( Some(stm)
regex_compiled_grammar,
ProtoGrammarType::Regex.into(),
Some(stm),
)
} }
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None), GrammarType::Regex(_regex) => None,
} }
} }
None => (String::new(), ProtoGrammarType::None.into(), None), None => None,
}; };
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
@ -438,8 +433,6 @@ impl Validation {
do_sample, do_sample,
seed, seed,
watermark, watermark,
grammar,
grammar_type,
states_to_token_maps, states_to_token_maps,
}; };
let stopping_parameters = StoppingCriteriaParameters { let stopping_parameters = StoppingCriteriaParameters {
@ -567,7 +560,6 @@ fn compile_grammar(
r#" r#"
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
import time
from transformers.file_utils import SPIECE_UNDERLINE from transformers.file_utils import SPIECE_UNDERLINE
class Tokenizer: class Tokenizer:
@ -589,13 +581,10 @@ class Tokenizer:
return " ".join(tokens) return " ".join(tokens)
def adapt_tokenizer(vocab, special_tokens): def adapt_tokenizer(vocab, special_tokens):
start_time = time.time()
tokenizer = Tokenizer(vocab, special_tokens) tokenizer = Tokenizer(vocab, special_tokens)
def convert_token_to_string(token: str) -> str: def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token]) string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers # A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string return " " + string
@ -603,26 +592,16 @@ def adapt_tokenizer(vocab, special_tokens):
return string return string
tokenizer.convert_token_to_string = convert_token_to_string tokenizer.convert_token_to_string = convert_token_to_string
print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
return tokenizer return tokenizer
def compile_regex_grammar(inputs, vocab, special_tokens): def compile_regex_grammar(inputs, vocab, special_tokens):
start_time = time.time()
print("🔥 starting compile_regex_grammar", inputs)
schema = build_regex_from_schema(inputs) schema = build_regex_from_schema(inputs)
print(f"Compiled grammar in {time.time() - start_time:.2f}s")
tokenizer = adapt_tokenizer(vocab, special_tokens) tokenizer = adapt_tokenizer(vocab, special_tokens)
print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
fsm = RegexFSM(schema, tokenizer) fsm = RegexFSM(schema, tokenizer)
print(f"Compiled grammar in {time.time() - start_time:.2f}s")
return fsm return fsm
def convert_grammar_to_regex(inputs): def convert_grammar_to_regex(inputs):
start_time = time.time() return build_regex_from_schema(inputs)
print("🔥 starting convert_grammar_to_regex", inputs)
schema = build_regex_from_schema(inputs)
print(f"Compiled grammar in {time.time() - start_time:.2f}s")
return schema
"#, "#,
"", "",
"", "",
@ -645,19 +624,7 @@ def convert_grammar_to_regex(inputs):
.getattr("states_to_token_maps")? .getattr("states_to_token_maps")?
.extract::<StateTokenMaps>()?; .extract::<StateTokenMaps>()?;
println!("🔥 elapsed: {:?}", start_time.elapsed());
// size of serialized states_to_token_maps
let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
println!(
"🔥 states_to_token_maps size: {:.2}MB",
states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
);
let result = regex_fsm.into_ref(py).extract().unwrap(); let result = regex_fsm.into_ref(py).extract().unwrap();
println!("result: {:?}", result);
Ok((result, states_to_token_maps)) Ok((result, states_to_token_maps))
}) })
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

View File

@ -475,22 +475,26 @@ class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int] fsm_state: DefaultDict[int, int]
fsm: RegexFSM fsm: RegexFSM
def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps): def __init__(self, tokenizer, device, states_to_token_maps):
self.device = device self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
# TODO: use the precompiled grammar here start_states = states_to_token_maps.start_states
self.states_to_token_maps = states_to_token_maps tokens = states_to_token_maps.tokens
precompiled_grammar = RegexFSM.precompiled( end_states = states_to_token_maps.end_states
states_to_token_maps=states_to_token_maps,
empty_token_ids=None,
vocabulary=None,
eos_token_id=None,
)
self.fsm = GrammarLogitProcessor._cached_compile_fsm( _states_to_token_maps = {}
grammar_type, grammar, self.tokenizer for i in range(len(start_states)):
) if start_states[i] in _states_to_token_maps:
_states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
else:
_states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
fsm = object.__new__(RegexFSM)
fsm.states_to_token_maps = _states_to_token_maps
fsm.empty_token_ids = None
fsm.vocabulary = list(tokenizer.get_vocab().values())
fsm.eos_token_id = tokenizer.eos_token_id
self.fsm = fsm
def __call__( def __call__(
self, self,
@ -560,17 +564,14 @@ class GrammarLogitProcessor(LogitsProcessor):
class HeterogeneousGrammarLogitProcessor(LogitsProcessor): class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__( def __init__(self, tokenizer, device, states_to_token_maps):
self, tokenizer, device, grammars, grammar_types, states_to_token_maps
):
self.device = device self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [] self.fsms = []
for grammar, grammar_type in zip(grammars, grammar_types): for states_to_token_map in states_to_token_maps:
start_states = states_to_token_maps[0].start_states start_states = states_to_token_map.start_states
tokens = states_to_token_maps[0].tokens tokens = states_to_token_map.tokens
end_states = states_to_token_maps[0].end_states end_states = states_to_token_map.end_states
_states_to_token_maps = {} _states_to_token_maps = {}
for i in range(len(start_states)): for i in range(len(start_states)):
@ -579,18 +580,11 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
else: else:
_states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]} _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
# TODO: cleanup how precompiled grammars are handled fsm = object.__new__(RegexFSM)
precompiled_grammar = RegexFSM.precompiled( fsm.states_to_token_maps = _states_to_token_maps
states_to_token_maps=_states_to_token_maps, fsm.empty_token_ids = None
empty_token_ids=None, fsm.vocabulary = list(tokenizer.get_vocab().values())
vocabulary=list(tokenizer.get_vocab().values()), fsm.eos_token_id = tokenizer.eos_token_id
eos_token_id=self.tokenizer.eos_token_id,
)
# fsm = GrammarLogitProcessor._cached_compile_fsm(
# grammar_type, grammar, self.tokenizer
# )
fsm = precompiled_grammar
self.fsms.append(fsm) self.fsms.append(fsm)
def __call__( def __call__(

View File

@ -36,8 +36,6 @@ class NextTokenChooser:
seed: int = 0, seed: int = 0,
device: str = "cpu", device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0, fsm_grammar_state: int = 0,
states_to_token_maps: List[List[int]] = None, states_to_token_maps: List[List[int]] = None,
): ):
@ -55,10 +53,8 @@ class NextTokenChooser:
else None else None
) )
self.grammar_processor = ( self.grammar_processor = (
GrammarLogitProcessor( GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
tokenizer, device, grammar, grammar_type, states_to_token_maps if states_to_token_maps
)
if grammar != ""
else None else None
) )
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -80,7 +76,6 @@ class NextTokenChooser:
self.choice = Sampling(seed, device) if sampling else Greedy() self.choice = Sampling(seed, device) if sampling else Greedy()
self.fsm_grammar_state = fsm_grammar_state self.fsm_grammar_state = fsm_grammar_state
self.grammar = grammar
self.states_to_token_maps = states_to_token_maps self.states_to_token_maps = states_to_token_maps
def __call__(self, input_ids, scores): def __call__(self, input_ids, scores):
@ -128,8 +123,6 @@ class NextTokenChooser:
seed=pb.seed, seed=pb.seed,
device=device, device=device,
tokenizer=tokenizer, tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
states_to_token_maps=pb.states_to_token_maps, states_to_token_maps=pb.states_to_token_maps,
) )
@ -236,8 +229,6 @@ class HeterogeneousNextTokenChooser:
do_sample: List[bool], do_sample: List[bool],
seeds: List[int], seeds: List[int],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states: List[int], fsm_grammar_states: List[int],
states_to_token_maps: List[List[List[int]]], states_to_token_maps: List[List[List[int]]],
): ):
@ -272,10 +263,8 @@ class HeterogeneousNextTokenChooser:
) )
self.grammar_processor = ( self.grammar_processor = (
HeterogeneousGrammarLogitProcessor( HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
tokenizer, device, grammars, grammar_types, states_to_token_maps if any(states_to_token_maps)
)
if any([grammar != "" for grammar in grammars])
else None else None
) )
@ -312,8 +301,6 @@ class HeterogeneousNextTokenChooser:
self.device = device self.device = device
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.fsm_grammar_states = fsm_grammar_states self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
self.states_to_token_maps = states_to_token_maps self.states_to_token_maps = states_to_token_maps
def __call__( def __call__(
@ -447,17 +434,11 @@ class HeterogeneousNextTokenChooser:
self.seeds = [self.seeds[i] for i in indices] self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = [] new_fsm_grammar_states = []
new_grammar_types = []
for i in indices: for i in indices:
new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i]) new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
if any(self.do_sample): if any(self.do_sample):
self.choice.filter(indices) self.choice.filter(indices)
@ -488,8 +469,6 @@ class HeterogeneousNextTokenChooser:
device=device, device=device,
dtype=dtype, dtype=dtype,
tokenizer=tokenizer, tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=( fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb) fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
), ),