Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: remove states from proto and simplify logit processor
This commit is contained in: parent a28ba7212c, commit 13e07b8257
@@ -46,7 +46,6 @@ pub async fn run(
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
         grammar: String::new(),
-        fsm_grammar_state: 0,
     };
 
     // Initialize terminal properties
@@ -376,6 +376,7 @@ def generate_load():
         n: int,
         seed: Optional[int] = None,
         grammar: Optional[str] = None,
+        stop_sequences: Optional[List[str]] = None,
     ) -> List[Response]:
         futures = [
             client.generate(
@@ -384,6 +385,7 @@ def generate_load():
                 decoder_input_details=True,
                 seed=seed,
                 grammar=grammar,
+                stop_sequences=stop_sequences,
             )
             for _ in range(n)
         ]
@@ -72,8 +72,6 @@ message NextTokenChooserParameters {
     bool watermark = 8;
     /// grammar (applied if not empty)
     string grammar = 10;
-    /// fsm_grammar_state
-    uint32 fsm_grammar_state = 11;
 }
 
 message StoppingCriteriaParameters {
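With fsm_grammar_state gone from NextTokenChooserParameters, callers only ship the grammar itself; the server owns the decoding state and seeds it to zero. A minimal sketch of building post-change parameters, assuming the generated generate_pb2 stubs (imported elsewhere in this diff) are on the path; the field values are illustrative:

# Illustrative only: there is no fsm_grammar_state field left to set.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    watermark=False,
    grammar="",  # grammar string; empty means "no grammar constraint"
)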
@@ -129,7 +129,6 @@ impl Client {
                 frequency_penalty: 0.1,
                 watermark: true,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
@@ -46,7 +46,6 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
-                fsm_grammar_state: 0,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
@@ -369,7 +369,6 @@ mod tests {
             frequency_penalty: 0.0,
             watermark: false,
             grammar: String::new(),
-            fsm_grammar_state: 0,
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,
@@ -293,9 +293,6 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
-        // init the start state of the grammar
-        let fsm_grammar_state = 0;
-
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -307,7 +304,6 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            fsm_grammar_state,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
@@ -1,6 +1,7 @@
 import math
 import torch
 
+import json
 from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
@@ -477,18 +478,19 @@ class GrammarLogitProcessor(LogitsProcessor):
 
     def __init__(self, tokenizer, device):
         self.device = device
-        self.tokenizer = tokenizer
+        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
 
-    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
-        if fsm_grammar_state == -1:
-            # todo mask for only eos token
+    def __call__(
+        self,
+        _input_ids: torch.Tensor,
+        logits: torch.Tensor,
+        fsm_grammar_state: int,
+        grammar: str,
+    ):
+        if fsm_grammar_state == -1 or grammar == "":
             return logits
 
-        # if grammar is '' or None, return the greedy token
-        if grammar == "" or grammar is None:
-            return logits
-
-        fsm = self.compile_fsm(grammar, self.tokenizer)
+        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
         mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
         mask[allowed_tokens] = 0
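The body above is standard grammar masking: token ids the FSM disallows get -inf added to their logits, so they can never win sampling or argmax. A self-contained toy version of that masking step (vocabulary size and allowed ids are made up):

import math
import torch

logits = torch.randn(10)        # scores over a toy vocabulary of 10 tokens
allowed_tokens = [2, 5, 7]      # ids the grammar FSM permits next

mask = torch.full((logits.shape[-1],), -math.inf)
mask[allowed_tokens] = 0.0      # allowed ids keep their original scores
biased_scores = logits + mask   # every other id becomes -inf

next_id = int(torch.argmax(biased_scores))
assert next_id in allowed_tokens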
@@ -502,20 +504,24 @@ class GrammarLogitProcessor(LogitsProcessor):
         if grammar == "" or grammar is None:
             return fsm_grammar_state
 
-        fsm = self.compile_fsm(grammar, self.tokenizer)
+        fsm = GrammarLogitProcessor._cached_compile_fsm(self, grammar, self.tokenizer)
         return fsm.next_state(fsm_grammar_state, next_token_id)
 
+    @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
+    def _cached_compile_fsm(self, schema, tokenizer):
         start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
+        try:
+            json.loads(schema)  # check if schema is a valid json
+            schema = build_regex_from_object(schema)  # convert schema to regex
+        except json.JSONDecodeError:
+            pass
+        fsm = RegexFSM(schema, tokenizer)
         logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
 
-    def adapt_tokenizer(self, tokenizer):
+    @staticmethod
+    def adapt_tokenizer(tokenizer):
         """Adapt tokenizer to work with the FSM.
 
         The API of Outlines tokenizers is slightly different to that of
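Because _cached_compile_fsm becomes a staticmethod that still takes self explicitly, the lru_cache key is the full (self, schema, tokenizer) tuple, so each distinct grammar is compiled once and reused on later calls. A runnable toy showing the same decorator stacking (the class and values here are invented for illustration):

from functools import lru_cache

class Compiler:
    @staticmethod
    @lru_cache(maxsize=32, typed=True)
    def _cached_compile(self, schema, tokenizer):
        # runs once per unique, hashable (self, schema, tokenizer) tuple
        print(f"compiling {schema!r}")
        return f"fsm<{schema}>"

c = Compiler()
Compiler._cached_compile(c, "a|b", "tok")  # prints "compiling 'a|b'"
Compiler._cached_compile(c, "a|b", "tok")  # cache hit: no recompilation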
@@ -548,68 +554,31 @@ class GrammarLogitProcessor(LogitsProcessor):
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
     def __init__(self, tokenizer, device):
         self.device = device
-        self.tokenizer = tokenizer
+        self.tokenizer = GrammarLogitProcessor.adapt_tokenizer(tokenizer)
 
-    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+    def __call__(
+        self,
+        _input_ids: torch.Tensor,
+        logits: torch.Tensor,
+        fsm_grammar_states: List[int],
+        grammars: List[str],
+    ):
         for i in range(logits.shape[0]):
-            if fsm_grammar_states[i] == -1:
-                # todo mask for only eos token
+            if fsm_grammar_states[i] == -1 or grammars[i] == "":
                 continue
 
-            # if grammar is '' or None, return the greedy token
-            if grammars[i] == "" or grammars[i] is None:
-                continue
-
-            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            fsm = GrammarLogitProcessor._cached_compile_fsm(
+                self, grammars[i], self.tokenizer
+            )
 
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
 
             mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
             mask[allowed_tokens] = 0
             biased_scores = logits[i] + mask
             logits[i] = biased_scores
         return logits
 
-    def advance(self, next_token_id, fsm_grammar_state, grammar):
-        if fsm_grammar_state == -1:
-            return fsm_grammar_state
-
-        if grammar == "" or grammar is None:
-            return fsm_grammar_state
-
-        fsm = self.compile_fsm(grammar, self.tokenizer)
-        return fsm.next_state(fsm_grammar_state, next_token_id)
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        logger.info(f"Compiled FSM in {time.time() - start_time:.2f}s")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+    def advance(self, next_token_ids, fsm_grammar_states, grammars):
+        return GrammarLogitProcessor.advance(
+            self, next_token_ids, fsm_grammar_states, grammars
+        )
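Instead of carrying its own copy of compile_fsm and adapt_tokenizer, HeterogeneousGrammarLogitProcessor.advance now calls GrammarLogitProcessor.advance as a plain function, passing itself as self; that works because both classes expose the tokenizer and device attributes the shared code touches. A minimal illustration of that cross-class delegation (the names here are invented):

class Base:
    def advance(self, state):
        # relies only on attributes the delegating class also defines
        return state + self.step

class Hetero:
    step = 2

    def advance(self, state):
        # call Base's implementation with this Hetero instance as `self`
        return Base.advance(self, state)

print(Hetero().advance(3))  # 5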
@@ -1,7 +1,5 @@
 import re
-from typing import Callable, List, Optional, Tuple, DefaultDict
-from collections import defaultdict
-import math
+from typing import List, Optional, Tuple, DefaultDict
 
 import torch
 from text_generation_server.pb import generate_pb2
@@ -38,7 +36,7 @@ class NextTokenChooser:
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         grammar: str = "",
-        fsm_grammar_state: Optional[DefaultDict[int, int]] = None,
+        fsm_grammar_state: int = 0,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -124,7 +122,6 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
-            fsm_grammar_state=pb.fsm_grammar_state,
         )
 
 
@@ -229,9 +226,9 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
-        tokenizer=None,
-        grammars=None,
-        fsm_grammar_states=None,
+        tokenizer: PreTrainedTokenizerBase,
+        grammars: List[str],
+        fsm_grammar_states=List[int],
     ):
         warpers = []
 
@@ -395,15 +392,9 @@ class HeterogeneousNextTokenChooser:
         # advance the grammar state
         if self.grammar_processor is not None:
             for i in range(len(self.fsm_grammar_states)):
-                try:
-                    self.fsm_grammar_states[i] = self.grammar_processor.advance(
-                        next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
-                    )
-                except:
-                    import ipdb
-
-                    ipdb.set_trace()
-                    pass
+                self.fsm_grammar_states[i] = self.grammar_processor.advance(
+                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
+                )
 
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
@@ -465,7 +456,7 @@ class HeterogeneousNextTokenChooser:
             dtype=dtype,
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
-            fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
+            fsm_grammar_states=[0] * len(pb),
         )
 
 
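Since the proto no longer carries state, from_pb seeds every request's FSM state to 0 and the chooser advances it locally with next_state as tokens are accepted, with -1 marking a finished grammar. A runnable toy FSM demonstrating that lifecycle through the allowed_token_ids/next_state interface seen above (an illustrative stand-in, not outlines' RegexFSM):

class ToyFSM:
    # grammar: accept token 1, then token 2, then stop (-1)
    transitions = {0: {1: 1}, 1: {2: -1}}

    def allowed_token_ids(self, state):
        return list(self.transitions.get(state, {}))

    def next_state(self, state, token_id):
        return self.transitions.get(state, {}).get(token_id, -1)

fsm = ToyFSM()
state = 0                       # every new request starts here
for token_id in (1, 2):
    assert token_id in fsm.allowed_token_ids(state)
    state = fsm.next_state(state, token_id)
print(state)                    # -1: grammar complete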