feat: remove states from proto and simplify logit processor

This commit is contained in:
drbh 2024-02-12 17:33:30 +00:00
parent a28ba7212c
commit 13e07b8257
9 changed files with 50 additions and 98 deletions

View File

@ -46,7 +46,6 @@ pub async fn run(
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
grammar: String::new(),
fsm_grammar_state: 0,
};
// Initialize terminal properties

View File

@ -376,6 +376,7 @@ def generate_load():
n: int,
seed: Optional[int] = None,
grammar: Optional[str] = None,
stop_sequences: Optional[List[str]] = None,
) -> List[Response]:
futures = [
client.generate(
@ -384,6 +385,7 @@ def generate_load():
decoder_input_details=True,
seed=seed,
grammar=grammar,
stop_sequences=stop_sequences,
)
for _ in range(n)
]

View File

@ -72,8 +72,6 @@ message NextTokenChooserParameters {
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// fsm_grammar_state
uint32 fsm_grammar_state = 11;
}
message StoppingCriteriaParameters {

View File

@ -129,7 +129,6 @@ impl Client {
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
fsm_grammar_state: 0,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,

View File

@ -46,7 +46,6 @@ impl Health {
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
fsm_grammar_state: 0,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@ -369,7 +369,6 @@ mod tests {
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
fsm_grammar_state: 0,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,

View File

@ -293,9 +293,6 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
// init the start state of the grammar
let fsm_grammar_state = 0;
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
@ -307,7 +304,6 @@ impl Validation {
seed,
watermark,
grammar,
fsm_grammar_state,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,

View File

@ -1,6 +1,7 @@
import math
import torch
import json
from loguru import logger
from functools import lru_cache
from typing import Optional, List, Dict, Union
@ -477,18 +478,19 @@ class GrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device):
self.device = device
self.tokenizer = tokenizer
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
if fsm_grammar_state == -1:
# todo mask for only eos token
def __call__(
self,
_input_ids: torch.Tensor,
logits: torch.Tensor,
fsm_grammar_state: int,
grammar: str,
):
if fsm_grammar_state == -1 or grammar == "":
return logits
# if grammar is '' or None, return the greedy token
if grammar == "" or grammar is None:
return logits
fsm = self.compile_fsm(grammar, self.tokenizer)
fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
mask[allowed_tokens] = 0
@ -502,20 +504,24 @@ class GrammarLogitProcessor(LogitsProcessor):
if grammar == "" or grammar is None:
return fsm_grammar_state
fsm = self.compile_fsm(grammar, self.tokenizer)
fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
return fsm.next_state(fsm_grammar_state, next_token_id)
@staticmethod
@lru_cache(maxsize=32, typed=True)
def compile_fsm(self, schema, tokenizer):
def _cached_compile_fsm(self, schema, tokenizer):
start_time = time.time()
tokenizer = self.adapt_tokenizer(tokenizer)
is_json_string = schema.startswith("{") and schema.endswith("}")
regex_string = build_regex_from_object(schema) if is_json_string else schema
fsm = RegexFSM(regex_string, tokenizer)
try:
json.loads(schema) # check if schema is a valid json
schema = build_regex_from_object(schema) # convert schema to regex
except json.JSONDecodeError:
pass
fsm = RegexFSM(schema, tokenizer)
logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
def adapt_tokenizer(self, tokenizer):
@staticmethod
def adapt_tokenizer(tokenizer):
"""Adapt tokenizer to work with the FSM.
The API of Outlines tokenizers is slightly different to that of
@ -548,68 +554,31 @@ class GrammarLogitProcessor(LogitsProcessor):
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device):
self.device = device
self.tokenizer = tokenizer
self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
def __call__(
self,
_input_ids: torch.Tensor,
logits: torch.Tensor,
fsm_grammar_states: List[int],
grammars: List[str],
):
for i in range(logits.shape[0]):
if fsm_grammar_states[i] == -1:
# todo mask for only eos token
if fsm_grammar_states[i] == -1 or grammars[i] == "":
continue
# if grammar is '' or None, return the greedy token
if grammars[i] == "" or grammars[i] is None:
continue
fsm = self.compile_fsm(grammars[i], self.tokenizer)
fsm = GrammarLogitProcessor._cached_compile_fsm(
self, grammars[i], self.tokenizer
)
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
mask[allowed_tokens] = 0
biased_scores = logits[i] + mask
logits[i] = biased_scores
return logits
def advance(self, next_token_id, fsm_grammar_state, grammar):
if fsm_grammar_state == -1:
return fsm_grammar_state
if grammar == "" or grammar is None:
return fsm_grammar_state
fsm = self.compile_fsm(grammar, self.tokenizer)
return fsm.next_state(fsm_grammar_state, next_token_id)
@lru_cache(maxsize=32, typed=True)
def compile_fsm(self, schema, tokenizer):
start_time = time.time()
tokenizer = self.adapt_tokenizer(tokenizer)
is_json_string = schema.startswith("{") and schema.endswith("}")
regex_string = build_regex_from_object(schema) if is_json_string else schema
fsm = RegexFSM(regex_string, tokenizer)
logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
def adapt_tokenizer(self, tokenizer):
"""Adapt tokenizer to work with the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
tokenizer.convert_token_to_string = convert_token_to_string
return tokenizer
def advance(self, next_token_ids, fsm_grammar_states, grammars):
return GrammarLogitProcessor.advance(
self, next_token_ids, fsm_grammar_states, grammars
)

View File

@ -1,7 +1,5 @@
import re
from typing import Callable, List, Optional, Tuple, DefaultDict
from collections import defaultdict
import math
from typing import List, Optional, Tuple, DefaultDict
import torch
from text_generation_server.pb import generate_pb2
@ -38,7 +36,7 @@ class NextTokenChooser:
device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "",
fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
fsm_grammar_state: int = 0,
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@ -124,7 +122,6 @@ class NextTokenChooser:
device=device,
tokenizer=tokenizer,
grammar=pb.grammar,
fsm_grammar_state=pb.fsm_grammar_state,
)
@ -229,9 +226,9 @@ class HeterogeneousNextTokenChooser:
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
tokenizer=None,
grammars=None,
fsm_grammar_states=None,
tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
fsm_grammar_states=List[int],
):
warpers = []
@ -395,15 +392,9 @@ class HeterogeneousNextTokenChooser:
# advance the grammar state
if self.grammar_processor is not None:
for i in range(len(self.fsm_grammar_states)):
try:
self.fsm_grammar_states[i] = self.grammar_processor.advance(
next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
)
except:
import ipdb
ipdb.set_trace()
pass
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
@ -465,7 +456,7 @@ class HeterogeneousNextTokenChooser:
dtype=dtype,
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
fsm_grammar_states=[0] * len(pb),
)