mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

fix: prefer grammar as logit processor

This commit is contained in:
parent a1c630d5c1
commit f1d43f2df4
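Summary: the parent commit applied grammar constraints inside a dedicated `Grammar` next-token "choice" class, which masked logits, picked a token, and advanced the FSM in a single step. This commit reframes the grammar as a logit processor: `GrammarLogitProcessor` and `HeterogeneousGrammarLogitProcessor` bias the scores with a -inf mask before the ordinary Greedy/Sampling choice runs, and the FSM state is advanced explicitly after each chosen token. The per-batch `fsm_states` field is dropped from `FlashCausalLMBatch`; grammar state now lives on the token choosers.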
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -99,9 +99,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
-    # The states for the grammar FSM
-    fsm_states: Dict[int, int] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -140,7 +137,6 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
-        fsm_states = {}
 
         all_prefill_logprobs = True
         no_prefill_logprobs = True
@@ -323,7 +319,6 @@ class FlashCausalLMBatch(Batch):
             blocks=blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
-            fsm_states=fsm_states,
         )
 
     @tracer.start_as_current_span("filter")
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -4,6 +4,11 @@ import torch
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
 
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+from typing import DefaultDict
+import time
+
 from transformers import (
     LogitsWarper,
     LogitsProcessor,
@@ -135,9 +141,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
     ) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty, score / self.penalty
-        )
+        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         return scores.scatter_add_(1, input_ids, score)
 
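For reference, the reformatted penalty step computes the same thing as the multi-line version it replaces. A minimal, self-contained sketch of the arithmetic (tensor values are illustrative, not from the repository):

import torch

penalty = 1.2
scores = torch.tensor([[0.5, -0.3, 2.0, 1.0]])   # [batch, vocab] logits
input_ids = torch.tensor([[2, 2, 1]])            # previously generated tokens

# gather the logits of previously seen tokens; the score < 0 branch
# multiplies by the penalty, the other divides, then the negated values
# are added back in place, one contribution per occurrence
score = torch.gather(scores, 1, input_ids)
score = -torch.where(score < 0, score * penalty, score / penalty)
scores = scores.scatter_add_(1, input_ids, score)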
@@ -464,3 +468,147 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
             self.processors = new_processors
             return self
         return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            # TODO: mask for only the EOS token
+            return logits
+
+        # if the grammar is '' or None, leave the logits unchanged
+        if grammar == "" or grammar is None:
+            return logits
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+        return biased_scores
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+    def filter(self, indices, fsm_grammar_states, grammars):
+        return self
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+        for i in range(len(logits)):
+            if fsm_grammar_states[i] == -1:
+                # TODO: mask for only the EOS token
+                continue
+
+            # if the grammar is '' or None, leave this row unchanged
+            if grammars[i] == "" or grammars[i] is None:
+                continue
+
+            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+            mask[allowed_tokens] = 0
+            biased_scores = logits[i] + mask
+            logits[i] = biased_scores
+        return logits
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        # print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
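The core of both processors is the same masking step. A self-contained sketch, with `allowed_tokens` standing in for `fsm.allowed_token_ids(state)` (illustrative values, no outlines dependency):

import math
import torch

logits = torch.tensor([1.0, 3.0, 0.5, 2.0])
allowed_tokens = [0, 3]  # hypothetical output of the FSM for this state

mask = torch.full((logits.shape[-1],), -math.inf)
mask[allowed_tokens] = 0
biased_scores = logits + mask  # disallowed tokens can never be selected

assert biased_scores.argmax().item() in allowed_tokens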
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -8,6 +8,7 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
+    GrammarLogitProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousGrammarLogitProcessor,
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -58,6 +60,11 @@ class NextTokenChooser:
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.grammar_processor = (
+            GrammarLogitProcessor(tokenizer, device)
+            if grammar and grammar != ""
+            else None
+        )
         self.tokenizer = tokenizer
 
         has_warpers = (
@@ -75,14 +82,9 @@ class NextTokenChooser:
 
         sampling = do_sample or has_warpers
 
+        self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
-        self.grammars = grammar
-
-        # TODO: is grammar a subset of sampling? If so, we should merge them
-        if grammar:
-            self.choice = Grammar(tokenizer, device, grammar)
-        else:
-            self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.grammar = grammar
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -91,15 +93,23 @@ class NextTokenChooser:
             scores = self.repetition_processor(input_ids, scores)
         if self.frequency_processor is not None:
             scores = self.frequency_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(
+                input_ids, scores, self.fsm_grammar_state, self.grammar
+            )
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
         else:
             scores, next_logprob = self.static_warper(scores)
 
-        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
-            1, 1
-        )
+        next_id = self.choice(scores[-1]).view(1, 1)
+
+        if self.grammar_processor is not None:
+            next_state = self.grammar_processor.advance(
+                next_id.item(), self.fsm_grammar_state, self.grammar
+            )
+            self.fsm_grammar_state = next_state
 
         return next_id, next_logprob
 
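The new `__call__` flow is: mask scores with the grammar processor, pick a token with the ordinary choice, then advance the FSM. A sketch of that control flow with a hypothetical stand-in processor (the real one is the outlines-backed class above):

import torch

class StubGrammarProcessor:
    def __call__(self, input_ids, scores, fsm_grammar_state, grammar):
        return scores  # the real processor adds a -inf mask here

    def advance(self, next_token_id, fsm_grammar_state, grammar):
        return fsm_grammar_state + 1  # the real processor steps the RegexFSM

processor = StubGrammarProcessor()
fsm_grammar_state, grammar = 0, "[0-9]+"
scores = torch.randn(1, 8)

scores = processor(None, scores, fsm_grammar_state, grammar)
next_id = scores[-1].argmax().view(1, 1)  # the Greedy choice
fsm_grammar_state = processor.advance(next_id.item(), fsm_grammar_state, grammar)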
@@ -229,7 +239,7 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
         tokenizer=None,
-        grammar=None,
+        grammars=None,
         fsm_grammar_states=None,
     ):
         warpers = []
@@ -262,6 +272,12 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        self.grammar_processor = (
+            HeterogeneousGrammarLogitProcessor(tokenizer, device)
+            if any([grammar != "" and grammar is not None for grammar in grammars])
+            else None
+        )
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -284,21 +300,18 @@ class HeterogeneousNextTokenChooser:
 
         self.warpers = warpers
 
-        if grammar is not None:
-            self.choice = Grammar(tokenizer, device)
-        elif any(do_sample):
+        if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             self.choice = Greedy()
 
-        self.use_grammar = grammar is not None
         self.seeds = seeds
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammar
+        self.grammars = grammars
 
     def __call__(
         self,
@@ -327,11 +340,13 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
 
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
 
-            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(
+                    input_ids, _scores, self.fsm_grammar_states, self.grammars
+                )
+            _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
@@ -386,6 +401,13 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None
 
+        # advance the grammar state
+        if self.grammar_processor is not None:
+            for i in range(len(self.fsm_grammar_states)):
+                self.fsm_grammar_states[i] = self.grammar_processor.advance(
+                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
+                )
+
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
     def filter(self, indices):
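What "advancing" means here, shown as a toy transition table rather than the outlines FSM API (a state of -1 marks a finished grammar, matching the checks above):

toy_fsm = {(0, 7): 1, (1, 9): -1}  # (state, token_id) -> next state
state = 0
for token_id in (7, 9):
    state = toy_fsm.get((state, token_id), -1)
assert state == -1  # grammar exhausted; generation can stop at EOS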
@@ -408,12 +436,16 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
-        if self.use_grammar or any(self.do_sample):
-            _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
-            self.fsm_grammar_states = new_fsm_grammar_states
-            self.grammars = new_grammars
-        else:
-            self.choice = Greedy()
+        new_grammars = []
+        new_fsm_grammar_states = []
+        for i in indices:
+            new_grammars.append(self.grammars[i])
+            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+
+        self.grammars = new_grammars
+        self.fsm_grammar_states = new_fsm_grammar_states
+
+        self.choice = Greedy()
 
         return self
 
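The rewritten filter keeps per-request grammar state aligned with the entries that survive batch filtering; the same pattern in isolation, with illustrative values:

grammars = ["", '{"type": "string"}', "[a-z]+"]
fsm_grammar_states = [0, 3, -1]
indices = [1, 2]  # requests kept after filtering

grammars = [grammars[i] for i in indices]
fsm_grammar_states = [fsm_grammar_states[i] for i in indices]
assert grammars == ['{"type": "string"}', "[a-z]+"]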
@@ -438,7 +470,7 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            grammar=[pb_.grammar for pb_ in pb],
+            grammars=[pb_.grammar for pb_ in pb],
             fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
         )
 
@@ -449,7 +481,7 @@ class Sampling:
         self.generator.manual_seed(seed)
         self.seed = seed
 
-    def __call__(self, logits):
+    def __call__(self, logits, *args):
         probs = torch.nn.functional.softmax(logits, -1)
         # Avoid GPU<->CPU sync done by torch multinomial
         # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
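The added `*args` appears to be a compatibility shim: some call sites still pass grammar state to the choice positionally, and accepting-and-ignoring extra arguments keeps both call shapes working. Illustrative only:

def choose(logits, *args):  # hypothetical choice; extra state is ignored
    return max(range(len(logits)), key=logits.__getitem__)

assert choose([0.1, 0.9]) == choose([0.1, 0.9], {0: 0}, ["[a-z]+"]) == 1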
@@ -464,97 +496,6 @@ class Greedy:
     def filter(self, indices, *args):
         return self
 
-
-class Grammar:
-    fsm_state: DefaultDict[int, int]
-    fsm: RegexFSM
-
-    def __init__(self, tokenizer, device):
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def __call__(self, logits, fsm_grammar_states, grammars):
-        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
-        try:
-            for i in range(len(fsm_grammar_states)):
-                if fsm_grammar_states[i] == -1:
-                    continue
-
-                # if grammar is '' or None, return the greedy token
-                if grammars[i] == "" or grammars[i] is None:
-                    empty[i] = logits[i].argmax().item()
-                    continue
-
-                # this is cached and should be fast after the first time
-                fsm = self.compile_fsm(grammars[i], self.tokenizer)
-                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-                mask[allowed_tokens] = 0
-                biased_scores = logits[i : i + 1] + mask
-                greedy = biased_scores.argmax(dim=-1)
-
-                # if greedy is empty, return the eos token
-                if greedy.shape[0] == 0:
-                    continue
-
-                # import ipdb; ipdb.set_trace()
-                fsm_grammar_states[i] = fsm.next_state(
-                    fsm_grammar_states[i], greedy.item()
-                )
-
-                empty[i] = greedy.item()
-        except Exception as e:
-            print(f"Exception: {e}")
-            import ipdb
-
-            ipdb.set_trace()
-        return empty
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-    def filter(self, indices, fsm_grammar_states, grammars):
-        new_fsm_grammar_states = []
-        new_grammars = []
-
-        for i in indices:
-            new_fsm_grammar_states.append(fsm_grammar_states[i])
-            new_grammars.append(grammars[i])
-
-        return self, new_fsm_grammar_states, new_grammars
-
 
 class HeterogeneousSampling:
     r"""
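With the constraint logic (and the duplicated `compile_fsm`/`adapt_tokenizer` helpers) now living in the logit processors, the `Grammar` choice class is deleted outright: token choice is handled by plain `Greedy`, `Sampling`, or `HeterogeneousSampling`, and the grammar mask is applied before the choice is made rather than inside it.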
@@ -574,7 +515,7 @@ class HeterogeneousSampling:
 
         self.greedy = Greedy()
 
-    def __call__(self, logits, fsm_grammar_states, grammars):
+    def __call__(self, logits):
         out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing