diff --git a/Dockerfile b/Dockerfile
index e79372a3..0b654598 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -250,5 +250,7 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base
 
+ENV LD_LIBRARY_PATH=/opt/conda/lib/:$LD_LIBRARY_PATH
+
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index 2d545ba6..58ea90b1 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -8,7 +8,7 @@ use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
 use std::io;
-use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
+use text_generation_client::{NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;
@@ -45,8 +45,6 @@ pub async fn run(
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
-        grammar: String::new(),
-        grammar_type: GrammarType::None as i32,
         states_to_token_maps: None,
     };
diff --git a/proto/generate.proto b/proto/generate.proto
index 11c3fef5..b51d02e1 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -51,12 +51,6 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}
 
-enum GrammarType {
-    GRAMMAR_TYPE_NONE = 0;
-    GRAMMAR_TYPE_JSON = 1;
-    GRAMMAR_TYPE_REGEX = 2;
-}
-
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
@@ -76,10 +70,6 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
-    /// grammar (applied if not empty)
-    string grammar = 10;
-    /// grammar type
-    GrammarType grammar_type = 11;
     /// states to token maps
     StatesToTokenMaps states_to_token_maps = 12;
 }
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 2fc06630..eb5e544b 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -135,8 +135,6 @@ impl Client {
                     repetition_penalty: 1.2,
                     frequency_penalty: 0.1,
                     watermark: true,
-                    grammar: String::new(),
-                    grammar_type: GrammarType::None as i32,
                     states_to_token_maps: None,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index d613362f..6333021a 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/health.rs b/router/src/health.rs
index 140367e8..5de93b74 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,6 +1,5 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
-use text_generation_client::GrammarType as ProtoGrammarType;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
@@ -46,8 +45,6 @@ impl Health {
                 repetition_penalty: 1.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                grammar_type: ProtoGrammarType::None as i32,
                 states_to_token_maps: None,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
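Note on the wire change above: NextTokenChooserParameters drops the grammar string and its GrammarType tag, leaving StatesToTokenMaps as the only grammar payload. Judging by the field accesses in the Python changes further down, that message carries the FSM transition table as three parallel arrays. A minimal sketch of that flattening, with a hypothetical helper name and made-up example values:

    # Hypothetical helper: flatten a nested {state: {token_id: next_state}}
    # transition map into the parallel (start_states, tokens, end_states)
    # arrays that StatesToTokenMaps appears to carry.
    def flatten_transitions(transitions):
        start_states, tokens, end_states = [], [], []
        for start, token_map in transitions.items():
            for token, end in token_map.items():
                start_states.append(start)
                tokens.append(token)
                end_states.append(end)
        return start_states, tokens, end_states

    # {0: {5: 1, 7: 2}, 1: {5: 2}} flattens to three aligned arrays:
    assert flatten_transitions({0: {5: 1, 7: 2}, 1: {5: 2}}) == ([0, 0, 1], [5, 7, 5], [1, 2, 2])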
diff --git a/router/src/queue.rs b/router/src/queue.rs
index dd3c7884..76b17991 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -343,9 +343,7 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{
-        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-    };
+    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
     use tracing::info_span;
 
     fn default_entry() -> (
@@ -370,8 +368,6 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                grammar_type: ProtoGrammarType::None as i32,
                 states_to_token_maps: None,
             },
             stopping_parameters: StoppingCriteriaParameters {
diff --git a/router/src/validation.rs b/router/src/validation.rs
index b16b5048..bd0878dd 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -7,8 +7,7 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
-    StoppingCriteriaParameters,
+    NextTokenChooserParameters, StatesToTokenMaps, StoppingCriteriaParameters,
 };
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -368,7 +367,7 @@ impl Validation {
         // compiler and use that to build the FSM here.
 
         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type, states_to_token_maps) = match grammar {
+        let states_to_token_maps = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
@@ -392,7 +391,7 @@ impl Validation {
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
                         // NOTE: this is the first step to compile the grammar
-                        let (regex_compiled_grammar, _states_to_token_maps) = self
+                        let (_regex_compiled_grammar, _states_to_token_maps) = self
                             .compile_grammar(serde_json::to_string(&json).unwrap())
                             .await
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
@@ -416,16 +415,12 @@ impl Validation {
                             end_states,
                         };
 
-                        (
-                            regex_compiled_grammar,
-                            ProtoGrammarType::Regex.into(),
-                            Some(stm),
-                        )
+                        Some(stm)
                     }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
+                    GrammarType::Regex(_regex) => None,
                 }
             }
-            None => (String::new(), ProtoGrammarType::None.into(), None),
+            None => None,
         };
 
         let parameters = NextTokenChooserParameters {
@@ -438,8 +433,6 @@ impl Validation {
             do_sample,
             seed,
             watermark,
-            grammar,
-            grammar_type,
             states_to_token_maps,
         };
         let stopping_parameters = StoppingCriteriaParameters {
@@ -567,7 +560,6 @@ fn compile_grammar(
         r#"
 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_schema
-import time
 from transformers.file_utils import SPIECE_UNDERLINE
 
 class Tokenizer:
@@ -589,13 +581,10 @@ class Tokenizer:
         return " ".join(tokens)
 
 def adapt_tokenizer(vocab, special_tokens):
-    start_time = time.time()
     tokenizer = Tokenizer(vocab, special_tokens)
 
     def convert_token_to_string(token: str) -> str:
-        string = tokenizer.convert_tokens_to_string([token])
-
         # A hack to handle missing spaces to HF's Llama tokenizers
         if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
             return " " + string
@@ -603,26 +592,16 @@ def adapt_tokenizer(vocab, special_tokens):
         return string
 
     tokenizer.convert_token_to_string = convert_token_to_string
-    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     return tokenizer
 
 def compile_regex_grammar(inputs, vocab, special_tokens):
-    start_time = time.time()
-    print("🔥 starting compile_regex_grammar", inputs)
     schema = build_regex_from_schema(inputs)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
     tokenizer = adapt_tokenizer(vocab, special_tokens)
-    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     fsm = RegexFSM(schema, tokenizer)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
     return fsm
 
 def convert_grammar_to_regex(inputs):
-    start_time = time.time()
-    print("🔥 starting convert_grammar_to_regex", inputs)
-    schema = build_regex_from_schema(inputs)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
-    return schema
+    return build_regex_from_schema(inputs)
 "#,
             "",
             "",
@@ -645,19 +624,7 @@ def convert_grammar_to_regex(inputs):
             .getattr("states_to_token_maps")?
             .extract::()?;
 
-        println!("🔥 elapsed: {:?}", start_time.elapsed());
-
-        // size of serialized states_to_token_maps
-        let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
-        println!(
-            "🔥 states_to_token_maps size: {:.2}MB",
-            states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
-        );
-
         let result = regex_fsm.into_ref(py).extract().unwrap();
-
-        println!("result: {:?}", result);
-
         Ok((result, states_to_token_maps))
     })
     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
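For orientation, the embedded interpreter script above reduces to this pipeline (standalone sketch assuming outlines is installed; the FSM construction is commented out because it needs the adapted tokenizer shim defined in the script):

    # JSON schema -> regex -> FSM, mirroring convert_grammar_to_regex and
    # compile_regex_grammar in the embedded script.
    from outlines.fsm.json_schema import build_regex_from_schema

    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    regex = build_regex_from_schema(schema)  # what convert_grammar_to_regex returns
    # from outlines.fsm.fsm import RegexFSM
    # fsm = RegexFSM(regex, adapted_tokenizer)
    # fsm.states_to_token_maps is the nested dict the Rust caller reads back
    # and flattens into the StatesToTokenMaps proto message.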
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index b3bfb3f3..7f9f0f45 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -475,22 +475,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
+    def __init__(self, tokenizer, device, states_to_token_maps):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
-        # TODO: use the precompiled grammar here
-        self.states_to_token_maps = states_to_token_maps
-        precompiled_grammar = RegexFSM.precompiled(
-            states_to_token_maps=states_to_token_maps,
-            empty_token_ids=None,
-            vocabulary=None,
-            eos_token_id=None,
-        )
+        start_states = states_to_token_maps.start_states
+        tokens = states_to_token_maps.tokens
+        end_states = states_to_token_maps.end_states
 
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            grammar_type, grammar, self.tokenizer
-        )
+        _states_to_token_maps = {}
+        for i in range(len(start_states)):
+            if start_states[i] in _states_to_token_maps:
+                _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+            else:
+                _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+
+        fsm = object.__new__(RegexFSM)
+        fsm.states_to_token_maps = _states_to_token_maps
+        fsm.empty_token_ids = None
+        fsm.vocabulary = list(tokenizer.get_vocab().values())
+        fsm.eos_token_id = tokenizer.eos_token_id
+        self.fsm = fsm
 
     def __call__(
         self,
@@ -560,17 +564,14 @@ class GrammarLogitProcessor(LogitsProcessor):
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(
-        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
-    ):
+    def __init__(self, tokenizer, device, states_to_token_maps):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
-        for grammar, grammar_type in zip(grammars, grammar_types):
-            start_states = states_to_token_maps[0].start_states
-            tokens = states_to_token_maps[0].tokens
-            end_states = states_to_token_maps[0].end_states
+        for states_to_token_map in states_to_token_maps:
+            start_states = states_to_token_map.start_states
+            tokens = states_to_token_map.tokens
+            end_states = states_to_token_map.end_states
 
             _states_to_token_maps = {}
             for i in range(len(start_states)):
@@ -579,18 +580,11 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
                 else:
                     _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
 
-            # TODO: cleanup how precompiled grammars are handled
-            precompiled_grammar = RegexFSM.precompiled(
-                states_to_token_maps=_states_to_token_maps,
-                empty_token_ids=None,
-                vocabulary=list(tokenizer.get_vocab().values()),
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-            # fsm = GrammarLogitProcessor._cached_compile_fsm(
-            #     grammar_type, grammar, self.tokenizer
-            # )
-
-            fsm = precompiled_grammar
+            fsm = object.__new__(RegexFSM)
+            fsm.states_to_token_maps = _states_to_token_maps
+            fsm.empty_token_ids = None
+            fsm.vocabulary = list(tokenizer.get_vocab().values())
+            fsm.eos_token_id = tokenizer.eos_token_id
             self.fsms.append(fsm)
 
     def __call__(
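Both processors above rebuild the nested transition map from the parallel arrays, then hydrate a RegexFSM through object.__new__, which bypasses RegexFSM.__init__ and therefore skips regex compilation entirely; the four attribute assignments supply everything the FSM is expected to expose at decode time. The rebuild loop is equivalent to this dict.setdefault form (sketch, round-tripping the flattening example shown earlier):

    def unflatten_transitions(start_states, tokens, end_states):
        # Inverse of the proto flattening: parallel arrays back into
        # {state: {token_id: next_state}}.
        transitions = {}
        for start, token, end in zip(start_states, tokens, end_states):
            transitions.setdefault(start, {})[token] = end
        return transitions

    assert unflatten_transitions([0, 0, 1], [5, 7, 5], [1, 2, 2]) == {0: {5: 1, 7: 2}, 1: {5: 2}}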
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index ea8fed86..1a65bdb8 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -36,8 +36,6 @@ class NextTokenChooser:
         seed: int = 0,
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        grammar: str = "",
-        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
         states_to_token_maps: List[List[int]] = None,
     ):
@@ -55,10 +53,8 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(
-                tokenizer, device, grammar, grammar_type, states_to_token_maps
-            )
-            if grammar != ""
+            GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
+            if states_to_token_maps
             else None
         )
         self.tokenizer = tokenizer
@@ -80,7 +76,6 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
         self.fsm_grammar_state = fsm_grammar_state
-        self.grammar = grammar
         self.states_to_token_maps = states_to_token_maps
 
     def __call__(self, input_ids, scores):
@@ -128,8 +123,6 @@ class NextTokenChooser:
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
-            grammar=pb.grammar,
-            grammar_type=pb.grammar_type,
             states_to_token_maps=pb.states_to_token_maps,
         )
@@ -236,8 +229,6 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
         tokenizer: PreTrainedTokenizerBase,
-        grammars: List[str],
-        grammar_types: List[int],
         fsm_grammar_states: List[int],
         states_to_token_maps: List[List[List[int]]],
     ):
@@ -272,10 +263,8 @@ class HeterogeneousNextTokenChooser:
         )
 
         self.grammar_processor = (
-            HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types, states_to_token_maps
-            )
-            if any([grammar != "" for grammar in grammars])
+            HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
+            if any(states_to_token_maps)
             else None
         )
 
@@ -312,8 +301,6 @@ class HeterogeneousNextTokenChooser:
         self.device = device
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammars
-        self.grammar_types = grammar_types
         self.states_to_token_maps = states_to_token_maps
 
     def __call__(
@@ -447,17 +434,11 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
-        new_grammars = []
         new_fsm_grammar_states = []
-        new_grammar_types = []
         for i in indices:
-            new_grammars.append(self.grammars[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
-            new_grammar_types.append(self.grammar_types[i])
 
-        self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
-        self.grammar_types = new_grammar_types
 
         if any(self.do_sample):
             self.choice.filter(indices)
@@ -488,8 +469,6 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            grammars=[pb_.grammar for pb_ in pb],
-            grammar_types=[pb_.grammar_type for pb_ in pb],
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
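With the grammar strings gone, the choosers key everything off states_to_token_maps: a grammar processor is built only when a map is present, and per-request progress lives in fsm_grammar_state(s). As a hand-rolled stand-in for the RegexFSM masking that the diff does not show, one decoding step under such a transition map looks roughly like:

    import math

    def mask_scores(scores, fsm_state, transitions):
        # Tokens with no outgoing edge from the current FSM state are
        # disallowed: their scores drop to -inf before sampling.
        allowed = transitions.get(fsm_state, {})
        return [s if i in allowed else -math.inf for i, s in enumerate(scores)]

    transitions = {0: {2: 1}, 1: {3: 0}}
    print(mask_scores([0.1, 0.4, 0.2, 0.9], fsm_state=0, transitions=transitions))
    # [-inf, -inf, 0.2, -inf]: only token 2 is legal from state 0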