fix: prefer grammar as logit processor

Author: drbh
Date:   2024-02-10 01:41:22 +00:00
Parent: a1c630d5c1
Commit: f1d43f2df4

3 changed files with 211 additions and 127 deletions

View File

@@ -99,9 +99,6 @@ class FlashCausalLMBatch(Batch):
    # Maximum number of blocks
    max_blocks: int

-    # The states for the grammar FSM
-    fsm_states: Dict[int, int] = None
-
    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
@@ -140,7 +137,6 @@ class FlashCausalLMBatch(Batch):
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}
-        fsm_states = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
@@ -323,7 +319,6 @@ class FlashCausalLMBatch(Batch):
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
-            fsm_states=fsm_states,
        )

    @tracer.start_as_current_span("filter")

View File

@@ -4,6 +4,12 @@ import torch
from functools import lru_cache
from typing import Optional, List, Dict, Union

+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+from functools import lru_cache
+from typing import List, Optional, DefaultDict
+import time
+
from transformers import (
    LogitsWarper,
    LogitsProcessor,
@@ -135,9 +141,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    ) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty, score / self.penalty
-        )
+        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)

        return scores.scatter_add_(1, input_ids, score)
@@ -464,3 +468,147 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
            self.processors = new_processors
            return self
        return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            # todo mask for only eos token
+            return logits
+        # if grammar is '' or None, return the greedy token
+        if grammar == "" or grammar is None:
+            return logits
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+        return biased_scores
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+    def filter(self, indices, fsm_grammar_states, grammars):
+        return self
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+        for i in range(len(logits)):
+            if fsm_grammar_states[i] == -1:
+                # todo mask for only eos token
+                continue
+            # if grammar is '' or None, return the greedy token
+            if grammars[i] == "" or grammars[i] is None:
+                continue
+            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+            mask[allowed_tokens] = 0
+            biased_scores = logits[i] + mask
+            logits[i] = biased_scores
+        return logits
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        # print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
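Both new processors share the same mechanics: compile_fsm is memoized with functools.lru_cache, a grammar string that starts with "{" and ends with "}" is treated as a JSON schema and converted to a regex via build_regex_from_object, anything else is compiled directly as a regex, and every token the FSM does not allow from the current state is pushed to -inf before the token is chosen. A minimal, self-contained sketch of that masking arithmetic (the vocabulary size and allowed ids below are made up for illustration, not taken from the diff):

    import math

    import torch

    # Illustrative stand-ins for what fsm.allowed_token_ids(state) would return
    # for a real grammar; sizes are arbitrary.
    vocab_size = 10
    allowed_tokens = [2, 5, 7]

    logits = torch.randn(vocab_size)
    mask = torch.full((vocab_size,), -math.inf)
    mask[allowed_tokens] = 0

    biased_scores = logits + mask          # disallowed tokens are now -inf
    next_id = int(biased_scores.argmax())  # a greedy (or sampled) pick stays inside the grammar

    assert next_id in allowed_tokens

Worth noting about the caching: because compile_fsm is decorated on the method, cache entries are keyed by (self, schema, tokenizer) and the cache holds references to those objects, so with maxsize=32 at most 32 distinct grammars stay compiled at a time.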

View File

@@ -8,6 +8,7 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.logits_process import (
    FrequencyPenaltyLogitsProcessor,
+    GrammarLogitProcessor,
    HeterogeneousProcessorWrapper,
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
    HeterogeneousTopKLogitsWarper,
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
+    HeterogeneousGrammarLogitProcessor,
    static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -58,6 +60,11 @@ class NextTokenChooser:
            if frequency_penalty and frequency_penalty != 0.0
            else None
        )
+        self.grammar_processor = (
+            GrammarLogitProcessor(tokenizer, device)
+            if grammar and grammar != ""
+            else None
+        )
        self.tokenizer = tokenizer

        has_warpers = (
@@ -75,14 +82,9 @@ class NextTokenChooser:
        sampling = do_sample or has_warpers
+        self.choice = Sampling(seed, device) if sampling else Greedy()

        self.fsm_grammar_state = fsm_grammar_state
-        self.grammars = grammar
-        # TODO: is grammar a subset of sampling? If so, we should merge them
-        if grammar:
-            self.choice = Grammar(tokenizer, device, grammar)
-        else:
-            self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.grammar = grammar

    def __call__(self, input_ids, scores):
        if self.watermark_processor is not None:
@@ -91,15 +93,23 @@ class NextTokenChooser:
            scores = self.repetition_processor(input_ids, scores)
        if self.frequency_processor is not None:
            scores = self.frequency_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(
+                input_ids, scores, self.fsm_grammar_state, self.grammar
+            )

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
        else:
            scores, next_logprob = self.static_warper(scores)

-        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
-            1, 1
-        )
+        next_id = self.choice(scores[-1]).view(1, 1)
+
+        if self.grammar_processor is not None:
+            next_state = self.grammar_processor.advance(
+                next_id.item(), self.fsm_grammar_state, self.grammar
+            )
+            self.fsm_grammar_state = next_state

        return next_id, next_logprob
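With this change a grammar-constrained request flows through the standard chooser path: the grammar mask is applied alongside the watermark, repetition and frequency processors, the existing Greedy/Sampling choice picks the token, and the chooser then advances its own fsm_grammar_state. A hedged sketch of the same loop driven directly against GrammarLogitProcessor (model_forward, max_new_tokens and the starting state of 0 are assumptions of this sketch, not part of the diff):

    import torch

    # Assumed to exist in the surrounding server code: a tokenizer, a device,
    # a grammar string, and some way to get next-token logits.
    processor = GrammarLogitProcessor(tokenizer, device)
    fsm_grammar_state = 0  # assumed FSM start state; -1 is treated above as "no constraint"
    generated = []

    for _ in range(max_new_tokens):
        logits = model_forward(input_ids)  # hypothetical helper returning a [vocab_size] tensor
        masked = processor(input_ids, logits, fsm_grammar_state, grammar)
        next_id = int(masked.argmax(dim=-1))  # greedy pick over the grammar-allowed tokens
        fsm_grammar_state = processor.advance(next_id, fsm_grammar_state, grammar)
        generated.append(next_id)
        # appending next_id to input_ids and checking for EOS is omitted for brevity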
@@ -229,7 +239,7 @@ class HeterogeneousNextTokenChooser:
        do_sample: List[bool],
        seeds: List[int],
        tokenizer=None,
-        grammar=None,
+        grammars=None,
        fsm_grammar_states=None,
    ):
        warpers = []
@@ -262,6 +272,12 @@ class HeterogeneousNextTokenChooser:
            else None
        )

+        self.grammar_processor = (
+            HeterogeneousGrammarLogitProcessor(tokenizer, device)
+            if any([grammar != "" and grammar is not None for grammar in grammars])
+            else None
+        )
+
        if any([x != 1.0 for x in temperature]):
            do_sample = [
                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -284,21 +300,18 @@ class HeterogeneousNextTokenChooser:
        self.warpers = warpers

-        if grammar is not None:
-            self.choice = Grammar(tokenizer, device)
-        elif any(do_sample):
+        if any(do_sample):
            self.choice = HeterogeneousSampling(do_sample, seeds, device)
        else:
            self.choice = Greedy()

-        self.use_grammar = grammar is not None
        self.seeds = seeds
        self.do_sample = do_sample
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer
        self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammar
+        self.grammars = grammars

    def __call__(
        self,
@@ -327,11 +340,13 @@ class HeterogeneousNextTokenChooser:
                _scores = self.repetition_processor(input_ids, _scores)
            if self.frequency_processor is not None:
                _scores = self.frequency_processor(input_ids, _scores)

            for warper in self.warpers:
                _scores = warper(input_ids, _scores)

-            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(
+                    input_ids, _scores, self.fsm_grammar_states, self.grammars
+                )
+
+            _next_ids = self.choice(_scores)
            scores[:, j] = _scores
            next_ids[:, j] = _next_ids

        next_ids = next_ids.view(B * S)
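In the heterogeneous path the grammar processor walks the batch row by row, leaves rows whose state is -1 or whose grammar is empty/None untouched, and writes the biased scores back into logits[i] in place before the usual HeterogeneousSampling/Greedy choice runs. A small sketch of that per-row pass with the FSM lookup stubbed out (allowed_ids_fn stands in for fsm.allowed_token_ids and is not part of the diff):

    import math

    import torch

    def mask_batch(logits, fsm_grammar_states, grammars, allowed_ids_fn):
        # Mirrors HeterogeneousGrammarLogitProcessor.__call__: skip unconstrained
        # rows, bias the rest toward their FSM-allowed tokens, mutate in place.
        for i in range(len(logits)):
            if fsm_grammar_states[i] == -1 or not grammars[i]:
                continue
            mask = torch.full((logits.shape[-1],), -math.inf)
            mask[allowed_ids_fn(grammars[i], fsm_grammar_states[i])] = 0
            logits[i] = logits[i] + mask
        return logits

    # Example with a toy "grammar" that only allows even token ids:
    batch = torch.randn(3, 8)
    states = [0, -1, 0]                   # row 1 has no active grammar
    grammars = [r"even", "", r"even"]     # illustrative strings only
    masked = mask_batch(batch, states, grammars, lambda g, s: [0, 2, 4, 6])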
@@ -386,6 +401,19 @@ class HeterogeneousNextTokenChooser:
        else:
            speculative_ids = None

+        # advance the grammar FSM state for each request in the batch
+        if self.grammar_processor is not None:
+            for i in range(len(self.fsm_grammar_states)):
+                self.fsm_grammar_states[i] = self.grammar_processor.advance(
+                    next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
+                )
+
        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

    def filter(self, indices):
@@ -408,12 +436,16 @@ class HeterogeneousNextTokenChooser:
        self.seeds = [self.seeds[i] for i in indices]
        self.do_sample = [self.do_sample[i] for i in indices]

-        if self.use_grammar or any(self.do_sample):
-            _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
-            self.fsm_grammar_states = new_fsm_grammar_states
-            self.grammars = new_grammars
-        else:
-            self.choice = Greedy()
+        new_grammars = []
+        new_fsm_grammar_states = []
+        for i in indices:
+            new_grammars.append(self.grammars[i])
+            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+
+        self.grammars = new_grammars
+        self.fsm_grammar_states = new_fsm_grammar_states
+
+        self.choice = Greedy()

        return self
@@ -438,7 +470,7 @@ class HeterogeneousNextTokenChooser:
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
-            grammar=[pb_.grammar for pb_ in pb],
+            grammars=[pb_.grammar for pb_ in pb],
            fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
        )
@@ -449,7 +481,7 @@ class Sampling:
        self.generator.manual_seed(seed)
        self.seed = seed

-    def __call__(self, logits):
+    def __call__(self, logits, *args):
        probs = torch.nn.functional.softmax(logits, -1)
        # Avoid GPU<->CPU sync done by torch multinomial
        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
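Greedy and Sampling previously had to receive the FSM state and grammars because the Grammar choice removed below consumed them; now that grammar masking happens in the logits processors, any extra positional arguments are simply tolerated and ignored, which keeps existing call sites working. A quick illustration, assuming the constructor shape Sampling(seed, device) used earlier in this file:

    import torch

    logits = torch.randn(32)
    sampler = Sampling(0, "cpu")

    token_a = sampler(logits)          # normal call
    token_b = sampler(logits, 0, "")   # extra state/grammar args are absorbed by *args and ignored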
@@ -464,97 +496,6 @@ class Greedy:
    def filter(self, indices, *args):
        return self

-
-class Grammar:
-    fsm_state: DefaultDict[int, int]
-    fsm: RegexFSM
-
-    def __init__(self, tokenizer, device):
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def __call__(self, logits, fsm_grammar_states, grammars):
-        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
-        try:
-            for i in range(len(fsm_grammar_states)):
-                if fsm_grammar_states[i] == -1:
-                    continue
-                # if grammar is '' or None, return the greedy token
-                if grammars[i] == "" or grammars[i] is None:
-                    empty[i] = logits[i].argmax().item()
-                    continue
-                # this is cached and should be fast after the first time
-                fsm = self.compile_fsm(grammars[i], self.tokenizer)
-                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-                mask[allowed_tokens] = 0
-                biased_scores = logits[i : i + 1] + mask
-                greedy = biased_scores.argmax(dim=-1)
-                # if greedy is empty, return the eos token
-                if greedy.shape[0] == 0:
-                    continue
-                # import ipdb; ipdb.set_trace()
-                fsm_grammar_states[i] = fsm.next_state(
-                    fsm_grammar_states[i], greedy.item()
-                )
-                empty[i] = greedy.item()
-        except Exception as e:
-            print(f"Exception: {e}")
-            import ipdb
-
-            ipdb.set_trace()
-        return empty
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-    def filter(self, indices, fsm_grammar_states, grammars):
-        new_fsm_grammar_states = []
-        new_grammars = []
-        for i in indices:
-            new_fsm_grammar_states.append(fsm_grammar_states[i])
-            new_grammars.append(grammars[i])
-        return self, new_fsm_grammar_states, new_grammars
-

class HeterogeneousSampling:
    r"""
@@ -574,7 +515,7 @@ class HeterogeneousSampling:
        self.greedy = Greedy()

-    def __call__(self, logits, fsm_grammar_states, grammars):
+    def __call__(self, logits):
        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing