From 13e07b8257950278467b452ad20b6ddc2f76c5d4 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 12 Feb 2024 17:33:30 +0000
Subject: [PATCH] feat: remove states from proto and simplify logit processor

---
 benchmark/src/lib.rs                           |   1 -
 integration-tests/conftest.py                  |   2 +
 proto/generate.proto                           |   2 -
 router/client/src/client.rs                    |   1 -
 router/src/health.rs                           |   1 -
 router/src/queue.rs                            |   1 -
 router/src/validation.rs                       |   4 -
 .../utils/logits_process.py                    | 109 +++++++-----------
 server/text_generation_server/utils/tokens.py |  27 ++---
 9 files changed, 50 insertions(+), 98 deletions(-)

diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index f38117b2..595e0cb9 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -46,7 +46,6 @@ pub async fn run(
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
         grammar: String::new(),
-        fsm_grammar_state: 0,
     };
 
     // Initialize terminal properties
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index f6c239f2..6fd3365d 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -376,6 +376,7 @@ def generate_load():
         n: int,
         seed: Optional[int] = None,
         grammar: Optional[str] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> List[Response]:
         futures = [
             client.generate(
@@ -384,6 +385,7 @@ def generate_load():
                 decoder_input_details=True,
                 seed=seed,
                 grammar=grammar,
+                stop_sequences=stop_sequences,
             )
             for _ in range(n)
         ]
diff --git a/proto/generate.proto b/proto/generate.proto
index 82081921..aae0e7a4 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -72,8 +72,6 @@ message NextTokenChooserParameters {
   bool watermark = 8;
   /// grammar (applied if not empty)
   string grammar = 10;
-  /// fsm_grammar_state
-  uint32 fsm_grammar_state = 11;
 }
 
 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 38e6e0e3..9822ea77 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -129,7 +129,6 @@ impl Client {
                 frequency_penalty: 0.1,
                 watermark: true,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
diff --git a/router/src/health.rs b/router/src/health.rs
index f3cac17e..6f3d2023 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -46,7 +46,6 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 0162b906..3e4aefa1 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -369,7 +369,6 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 4c38db68..a77995df 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -293,9 +293,6 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
-        // init the start state of the grammar
-        let fsm_grammar_state = 0;
-
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -307,7 +304,6 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            fsm_grammar_state,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
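Note: with fsm_grammar_state gone from NextTokenChooserParameters, the router never serializes FSM state; the server owns it for the whole lifetime of a request, and every sequence implicitly starts at state 0. A minimal sketch of the new ownership split (the class and field names here are illustrative stand-ins, not the actual TGI types):

    # Sketch: grammar FSM state lives server-side, one int per sequence;
    # the wire message carries only the grammar string itself.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class WireParams:              # stand-in for the proto message
        grammar: str = ""          # empty string means "unconstrained"

    @dataclass
    class ServerChooser:           # stand-in for HeterogeneousNextTokenChooser
        grammars: List[str]
        fsm_grammar_states: List[int] = field(default_factory=list)

        @classmethod
        def from_pb(cls, pb: List[WireParams]) -> "ServerChooser":
            # Every sequence starts at FSM state 0; the client cannot
            # (and need not) set or resume it.
            return cls(grammars=[p.grammar for p in pb],
                       fsm_grammar_states=[0] * len(pb))

    chooser = ServerChooser.from_pb([WireParams('{"type": "object"}'), WireParams()])
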
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 97086928..cc9a39ac 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -1,6 +1,7 @@
 import math
 import torch
+import json
 
 from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
@@ -477,18 +478,19 @@ class GrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device):
         self.device = device
-        self.tokenizer = tokenizer
+        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
 
-    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
-        if fsm_grammar_state == -1:
-            # todo mask for only eos token
+    def __call__(
+        self,
+        _input_ids: torch.Tensor,
+        logits: torch.Tensor,
+        fsm_grammar_state: int,
+        grammar: str,
+    ):
+        if fsm_grammar_state == -1 or grammar == "":
             return logits
-        # if grammar is '' or None, return the greedy token
-        if grammar == "" or grammar is None:
-            return logits
-
-        fsm = self.compile_fsm(grammar, self.tokenizer)
+        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
@@ -502,20 +504,24 @@ class GrammarLogitProcessor(LogitsProcessor):
         if grammar == "" or grammar is None:
             return fsm_grammar_state
 
-        fsm = self.compile_fsm(grammar, self.tokenizer)
+        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)
 
+    @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(self, schema, tokenizer):
         start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
+        try:
+            json.loads(schema)  # check whether the schema is valid JSON
+            schema = build_regex_from_object(schema)  # convert the schema to a regex
+        except json.JSONDecodeError:
+            pass  # not JSON: treat the schema as a regex directly
+        fsm = RegexFSM(schema, tokenizer)
         logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
-    def adapt_tokenizer(self, tokenizer):
+    @staticmethod
+    def adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.
 
         The API of Outlines tokenizers is slightly different to that of
@@ -548,68 +554,31 @@ class GrammarLogitProcessor(LogitsProcessor):
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device):
         self.device = device
-        self.tokenizer = tokenizer
+        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
 
-    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+    def __call__(
+        self,
+        _input_ids: torch.Tensor,
+        logits: torch.Tensor,
+        fsm_grammar_states: List[int],
+        grammars: List[str],
+    ):
         for i in range(logits.shape[0]):
-            if fsm_grammar_states[i] == -1:
-                # todo mask for only eos token
+            if fsm_grammar_states[i] == -1 or grammars[i] == "":
                 continue
-            # if grammar is '' or None, return the greedy token
-            if grammars[i] == "" or grammars[i] is None:
-                continue
-
-            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            fsm = GrammarLogitProcessor._cached_compile_fsm(
+                self, grammars[i], self.tokenizer
+            )
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+
             mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
             mask[allowed_tokens] = 0
             biased_scores = logits[i] + mask
             logits[i] = biased_scores
         return logits
 
-    def advance(self, next_token_id, fsm_grammar_state, grammar):
-        if fsm_grammar_state == -1:
-            return fsm_grammar_state
-
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-
-        fsm = self.compile_fsm(grammar, self.tokenizer)
-        return fsm.next_state(fsm_grammar_state, next_token_id)
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        return GrammarLogitProcessor.advance(
+            self, next_token_id, fsm_grammar_state, grammar
+        )
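Note: both processors constrain decoding the same way: build an additive mask that is -inf everywhere, re-open only the token ids the FSM allows from the current state, and add it to the logits so disallowed tokens get zero probability after softmax. A self-contained sketch of just that masking step (toy vocabulary, no Outlines dependency):

    import math
    import torch

    def mask_logits(logits: torch.Tensor, allowed_ids: list) -> torch.Tensor:
        # Start fully closed (-inf everywhere), then re-open the allowed ids.
        mask = torch.full((logits.shape[-1],), -math.inf)
        mask[allowed_ids] = 0
        return logits + mask

    logits = torch.randn(8)                   # toy vocabulary of 8 tokens
    constrained = mask_logits(logits, [1, 4])
    probs = torch.softmax(constrained, dim=-1)
    assert probs[0] == 0                      # disallowed ids can never be sampled
    assert abs(probs[1] + probs[4] - 1.0) < 1e-6

Also worth noting: because _cached_compile_fsm sits behind lru_cache, every argument, self and the adapted tokenizer included, becomes part of the cache key and must be hashable; presumably this is why the tokenizer is now adapted once in __init__ rather than on every compile.
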
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b056a522..6b1adce8 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,7 +1,5 @@
 import re
-from typing import Callable, List, Optional, Tuple, DefaultDict
-from collections import defaultdict
-import math
+from typing import List, Optional, Tuple, DefaultDict
 
 import torch
 
 from text_generation_server.pb import generate_pb2
@@ -38,7 +36,7 @@ class NextTokenChooser:
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         grammar: str = "",
-        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
+        fsm_grammar_state: int = 0,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -124,7 +122,6 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
-            fsm_grammar_state=pb.fsm_grammar_state,
         )
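Note: on the single-sequence side the grammar state is now just an int owned by the chooser: 0 at the start of every request, advanced after each accepted token, and, by the convention visible in the processor above, -1 once the grammar is finished. A hedged sketch of that bookkeeping against a toy FSM (a stand-in for the Outlines RegexFSM interface, not the real class):

    class ToyFSM:
        # Stand-in for outlines' RegexFSM: maps (state, token) -> next state,
        # with -1 treated as the terminal "grammar finished" state.
        def __init__(self, transitions: dict):
            self.transitions = transitions

        def next_state(self, state: int, token_id: int) -> int:
            return self.transitions.get((state, token_id), -1)

    fsm = ToyFSM({(0, 7): 1, (1, 3): -1})
    fsm_grammar_state = 0               # every request begins at state 0
    for token_id in (7, 3):             # ids produced by sampling
        if fsm_grammar_state == -1:     # finished grammars are left alone
            break
        fsm_grammar_state = fsm.next_state(fsm_grammar_state, token_id)
    assert fsm_grammar_state == -1      # grammar fully matched
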
- - """ - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - tokenizer.convert_token_to_string = convert_token_to_string - - return tokenizer + def advance(self, next_token_ids, fsm_grammar_states, grammars): + return GrammarLogitProcessor.advance( + self, next_token_ids, fsm_grammar_states, grammars + ) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index b056a522..6b1adce8 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,7 +1,5 @@ import re -from typing import Callable, List, Optional, Tuple, DefaultDict -from collections import defaultdict -import math +from typing import List, Optional, Tuple, DefaultDict import torch from text_generation_server.pb import generate_pb2 @@ -38,7 +36,7 @@ class NextTokenChooser: device: str = "cpu", tokenizer: Optional[PreTrainedTokenizerBase] = None, grammar: str = "", - fsm_grammar_state: Optional[DefaultDict[int, int]] = None, + fsm_grammar_state: int = 0, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -124,7 +122,6 @@ class NextTokenChooser: device=device, tokenizer=tokenizer, grammar=pb.grammar, - fsm_grammar_state=pb.fsm_grammar_state, ) @@ -229,9 +226,9 @@ class HeterogeneousNextTokenChooser: typical_p: List[float], do_sample: List[bool], seeds: List[int], - tokenizer=None, - grammars=None, - fsm_grammar_states=None, + tokenizer: PreTrainedTokenizerBase, + grammars: List[str], + fsm_grammar_states=List[int], ): warpers = [] @@ -395,15 +392,9 @@ class HeterogeneousNextTokenChooser: # advance the grammar state if self.grammar_processor is not None: for i in range(len(self.fsm_grammar_states)): - try: - self.fsm_grammar_states[i] = self.grammar_processor.advance( - next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i] - ) - except: - import ipdb - - ipdb.set_trace() - pass + self.fsm_grammar_states[i] = self.grammar_processor.advance( + next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i] + ) return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids @@ -465,7 +456,7 @@ class HeterogeneousNextTokenChooser: dtype=dtype, tokenizer=tokenizer, grammars=[pb_.grammar for pb_ in pb], - fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb], + fsm_grammar_states=[0] * len(pb), )
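A final note on the batched path: the advance loop above goes through HeterogeneousGrammarLogitProcessor.advance, which simply forwards to GrammarLogitProcessor.advance with its own self. In Python 3 a method accessed on the class is a plain function, so this works as long as the heterogeneous processor carries the attributes the single-sequence implementation touches (here, tokenizer). A minimal sketch of the pattern, with hypothetical names:

    class Single:
        def advance(self, state: int) -> int:
            # Relies only on attributes that Batched also defines.
            return state + self.step

    class Batched:
        def __init__(self, step: int):
            self.step = step

        def advance(self, state: int) -> int:
            # Delegate to Single's implementation; Python does not check
            # that `self` is actually a Single instance, only that the
            # attributes it uses exist.
            return Single.advance(self, state)

    assert Batched(step=2).advance(3) == 5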