feat: remove states from proto and simplify logit processor

drbh 2024-02-12 17:33:30 +00:00
parent a28ba7212c
commit 13e07b8257
9 changed files with 50 additions and 98 deletions

View File

@@ -46,7 +46,6 @@ pub async fn run(
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
         grammar: String::new(),
-        fsm_grammar_state: 0,
     };

     // Initialize terminal properties

View File

@@ -376,6 +376,7 @@ def generate_load():
     n: int,
     seed: Optional[int] = None,
     grammar: Optional[str] = None,
+    stop_sequences: Optional[List[str]] = None,
 ) -> List[Response]:
     futures = [
         client.generate(
@@ -384,6 +385,7 @@ def generate_load():
             decoder_input_details=True,
             seed=seed,
             grammar=grammar,
+            stop_sequences=stop_sequences,
         )
         for _ in range(n)
     ]
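Note: a minimal, self-contained sketch of the pattern this hunk adds — the load-test helper now threads an optional stop_sequences kwarg into every concurrent client.generate call. The helper's remaining parameters are elided in this hunk, so fake_generate below is a hypothetical stand-in for the real client:

import asyncio
from typing import List, Optional

async def generate_load_sketch(
    generate,                                  # stand-in for client.generate
    n: int,
    stop_sequences: Optional[List[str]] = None,
):
    # one concurrent request per future, all sharing the same stop sequences
    futures = [generate(stop_sequences=stop_sequences) for _ in range(n)]
    return await asyncio.gather(*futures)

async def fake_generate(stop_sequences=None):
    return f"stops={stop_sequences}"

print(asyncio.run(generate_load_sketch(fake_generate, n=2, stop_sequences=["\n"])))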

View File

@@ -72,8 +72,6 @@ message NextTokenChooserParameters {
   bool watermark = 8;
   /// grammar (applied if not empty)
   string grammar = 10;
-  /// fsm_grammar_state
-  uint32 fsm_grammar_state = 11;
 }

 message StoppingCriteriaParameters {
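Note: dropping field 11 removes the FSM start state from the wire format entirely; the server now owns that state and seeds it to 0 (see the NextTokenChooser hunks in the last file below). Standard protobuf hygiene applies: the retired number 11 should not be reused later for a field with a different meaning. A minimal sketch of building the message after this change, assuming regenerated Python stubs and showing only the fields visible in this hunk:

from text_generation_server.pb import generate_pb2

# There is no fsm_grammar_state field left to set; only the grammar string
# travels with the request (applied if not empty, per the comment above).
params = generate_pb2.NextTokenChooserParameters(
    watermark=False,
    grammar='{"type": "object"}',
)
print(params)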

View File

@@ -129,7 +129,6 @@ impl Client {
                 frequency_penalty: 0.1,
                 watermark: true,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,

View File

@@ -46,7 +46,6 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,

View File

@@ -369,7 +369,6 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,

View File

@@ -293,9 +293,6 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;

-        // init the start state of the grammar
-        let fsm_grammar_state = 0;
-
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -307,7 +304,6 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            fsm_grammar_state,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
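Note: together with the impl Client, impl Health, and router test hunks above, this removes the last router-side trace of the grammar state; nothing initializes or forwards it anymore. Initialization now lives entirely in the Python server, roughly as below (a toy class, not the real one; names mirror the NextTokenChooser diff at the end of this commit):

class ChooserSketch:
    def __init__(self, grammar: str = "", fsm_grammar_state: int = 0):
        self.grammar = grammar
        self.fsm_grammar_state = fsm_grammar_state  # every request starts at 0

print(ChooserSketch(grammar="yes|no").fsm_grammar_state)  # 0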

View File

@@ -1,6 +1,7 @@
 import math
 import torch
+import json

 from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
@@ -477,18 +478,19 @@ class GrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device):
         self.device = device
-        self.tokenizer = tokenizer
+        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)

-    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
-        if fsm_grammar_state == -1:
-            # todo mask for only eos token
+    def __call__(
+        self,
+        _input_ids: torch.Tensor,
+        logits: torch.Tensor,
+        fsm_grammar_state: int,
+        grammar: str,
+    ):
+        if fsm_grammar_state == -1 or grammar == "":
             return logits
-        # if grammar is '' or None, return the greedy token
-        if grammar == "" or grammar is None:
-            return logits
-        fsm = self.compile_fsm(grammar, self.tokenizer)
+        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
@@ -502,20 +504,24 @@ class GrammarLogitProcessor(LogitsProcessor):
         if grammar == "" or grammar is None:
             return fsm_grammar_state
-        fsm = self.compile_fsm(grammar, self.tokenizer)
+        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)

+    @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(self, schema, tokenizer):
         start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
+        try:
+            json.loads(schema)  # check if schema is a valid json
+            schema = build_regex_from_object(schema)  # convert schema to regex
+        except json.JSONDecodeError:
+            pass
+        fsm = RegexFSM(schema, tokenizer)
         logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm

-    def adapt_tokenizer(self, tokenizer):
+    @staticmethod
+    def adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.

         The API of Outlines tokenizers is slightly different to that of
@@ -548,68 +554,31 @@ class GrammarLogitProcessor(LogitsProcessor):
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device):
         self.device = device
-        self.tokenizer = tokenizer
+        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)

-    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+    def __call__(
+        self,
+        _input_ids: torch.Tensor,
+        logits: torch.Tensor,
+        fsm_grammar_states: List[int],
+        grammars: List[str],
+    ):
         for i in range(logits.shape[0]):
-            if fsm_grammar_states[i] == -1:
-                # todo mask for only eos token
+            if fsm_grammar_states[i] == -1 or grammars[i] == "":
                 continue
-            # if grammar is '' or None, return the greedy token
-            if grammars[i] == "" or grammars[i] is None:
-                continue
-            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            fsm = GrammarLogitProcessor._cached_compile_fsm(
+                self, grammars[i], self.tokenizer
+            )
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
             mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
             mask[allowed_tokens] = 0
             biased_scores = logits[i] + mask
             logits[i] = biased_scores
         return logits

-    def advance(self, next_token_id, fsm_grammar_state, grammar):
-        if fsm_grammar_state == -1:
-            return fsm_grammar_state
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-        fsm = self.compile_fsm(grammar, self.tokenizer)
-        return fsm.next_state(fsm_grammar_state, next_token_id)
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+    def advance(self, next_token_ids, fsm_grammar_states, grammars):
+        return GrammarLogitProcessor.advance(
+            self, next_token_ids, fsm_grammar_states, grammars
+        )
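Note: three things happen in this file. First, schema dispatch moves from a brittle startswith("{")/endswith("}") test to json.loads: anything that parses as JSON is treated as a JSON Schema and converted to a regex with Outlines' build_regex_from_object, anything else is used as a regex directly (one caveat: json.loads also accepts bare JSON scalars such as 1 or "x", which the old check rejected). Second, FSM compilation is memoized: stacking @staticmethod over @lru_cache keeps a single process-wide cache, with the explicitly passed self simply becoming another hashable component of the cache key. Third, constrained decoding works by adding -inf to every logit the FSM disallows, driving those tokens to zero probability after softmax. A standalone sketch of the cache-plus-mask combination, with a toy stand-in for RegexFSM:

import math
from functools import lru_cache

import torch

@lru_cache(maxsize=32)
def _compile_allowed(schema: str) -> frozenset:
    # Stand-in for RegexFSM compilation: pretend the grammar only allows
    # token ids 0 and 2. Repeated calls with the same schema hit the cache.
    return frozenset({0, 2})

def constrain(logits: torch.Tensor, schema: str) -> torch.Tensor:
    allowed = list(_compile_allowed(schema))
    mask = torch.full((logits.shape[-1],), -math.inf)
    mask[allowed] = 0  # allowed tokens keep their score, the rest go to -inf
    return logits + mask

print(constrain(torch.zeros(4), "a|b"))  # tensor([0., -inf, 0., -inf])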

View File

@@ -1,7 +1,5 @@
 import re
-from typing import Callable, List, Optional, Tuple, DefaultDict
-from collections import defaultdict
-import math
+from typing import List, Optional, Tuple, DefaultDict

 import torch
 from text_generation_server.pb import generate_pb2
@@ -38,7 +36,7 @@ class NextTokenChooser:
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         grammar: str = "",
-        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
+        fsm_grammar_state: int = 0,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -124,7 +122,6 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
-            fsm_grammar_state=pb.fsm_grammar_state,
         )
@@ -229,9 +226,9 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
-        tokenizer=None,
-        grammars=None,
-        fsm_grammar_states=None,
+        tokenizer: PreTrainedTokenizerBase,
+        grammars: List[str],
+        fsm_grammar_states=List[int],
     ):
         warpers = []
@@ -395,15 +392,9 @@ class HeterogeneousNextTokenChooser:
         # advance the grammar state
         if self.grammar_processor is not None:
             for i in range(len(self.fsm_grammar_states)):
-                try:
-                    self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                        next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                    )
-                except:
-                    import ipdb
-
-                    ipdb.set_trace()
-                    pass
+                self.fsm_grammar_states[i] = self.grammar_processor.advance(
+                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
+                )

         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
@@ -465,7 +456,7 @@ class HeterogeneousNextTokenChooser:
             dtype=dtype,
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
-            fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
+            fsm_grammar_states=[0] * len(pb),
        )
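Note: besides tightening the signatures, this file deletes leftover debug scaffolding — the bare except that dropped into an ipdb breakpoint — so failures in advance now propagate normally. (As committed, one signature reads fsm_grammar_states=List[int], with = rather than :, which makes the typing alias the default value as written.) The advance loop itself pushes each sequence's FSM one transition forward per emitted token, with -1 as the terminal state seeded fresh to 0 by from_pb. A toy end-to-end of that bookkeeping, using a hand-rolled FSM with the same allowed_token_ids/next_state surface as the Outlines RegexFSM used above:

class ToyFSM:
    # accept exactly the token sequence [5, 9]
    transitions = {(0, 5): 1, (1, 9): -1}

    def allowed_token_ids(self, state):
        return [tok for (s, tok) in self.transitions if s == state]

    def next_state(self, state, token_id):
        return self.transitions.get((state, token_id), -1)

fsm, state = ToyFSM(), 0
for token_id in [5, 9]:
    assert token_id in fsm.allowed_token_ids(state)  # mask everything else
    state = fsm.next_state(state, token_id)
print(state)  # -1: generation under this grammar is finished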
) )