Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00

Commit f1d43f2df4 — fix: prefer grammar as logit processor
Parent: a1c630d5c1
server/text_generation_server/models/flash_causal_lm.py

@@ -99,9 +99,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
-    # The states for the grammar FSM
-    fsm_states: Dict[int, int] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -140,7 +137,6 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
-        fsm_states = {}
 
         all_prefill_logprobs = True
         no_prefill_logprobs = True
@@ -323,7 +319,6 @@ class FlashCausalLMBatch(Batch):
             blocks=blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
-            fsm_states=fsm_states,
         )
 
     @tracer.start_as_current_span("filter")
server/text_generation_server/utils/logits_process.py

@@ -4,6 +4,12 @@ import torch
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
 
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+from functools import lru_cache
+from typing import List, Optional, DefaultDict
+import time
+
 from transformers import (
     LogitsWarper,
     LogitsProcessor,
@@ -135,9 +141,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
     ) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty, score / self.penalty
-        )
+        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         return scores.scatter_add_(1, input_ids, score)
 
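The frequency-penalty hunk above is a pure reflow (the three-line torch.where becomes one line). For readers tracing the math, here is a minimal standalone sketch of what that line does; the penalty value and toy tensors are illustrative, not from the commit:

import torch

penalty = 1.5
scores = torch.tensor([[2.0, 1.0, 0.5]])   # logits over a 3-token vocab
input_ids = torch.tensor([[1, 2]])         # tokens already generated

score = torch.gather(scores, 1, input_ids)  # scores of previously seen tokens
# positive scores are divided by the penalty, negative ones multiplied,
# and the sign is flipped so that scatter_add_ subtracts the term
score = -torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_add_(1, input_ids, score)    # seen positive scores shrink in place
print(scores)  # tensor([[2.0000, 0.3333, 0.1667]])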
@@ -464,3 +468,147 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
             self.processors = new_processors
             return self
         return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            # todo mask for only eos token
+            return logits
+
+        # if grammar is '' or None, return the greedy token
+        if grammar == "" or grammar is None:
+            return logits
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_state)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+        return biased_scores
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+    def filter(self, indices, fsm_grammar_states, grammars):
+        return self
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+    def __init__(self, tokenizer, device):
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids: torch.Tensor, logits, fsm_grammar_states, grammars):
+        for i in range(len(logits)):
+            if fsm_grammar_states[i] == -1:
+                # todo mask for only eos token
+                continue
+
+            # if grammar is '' or None, return the greedy token
+            if grammars[i] == "" or grammars[i] is None:
+                continue
+
+            fsm = self.compile_fsm(grammars[i], self.tokenizer)
+            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+            mask[allowed_tokens] = 0
+            biased_scores = logits[i] + mask
+            logits[i] = biased_scores
+        return logits
+
+    def advance(self, next_token_id, fsm_grammar_state, grammar):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+
+        if grammar == "" or grammar is None:
+            return fsm_grammar_state
+
+        fsm = self.compile_fsm(grammar, self.tokenizer)
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    @lru_cache(maxsize=32, typed=True)
+    def compile_fsm(self, schema, tokenizer):
+        start_time = time.time()
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        is_json_string = schema.startswith("{") and schema.endswith("}")
+        regex_string = build_regex_from_object(schema) if is_json_string else schema
+        fsm = RegexFSM(regex_string, tokenizer)
+        # print(f"Compile FSM: {time.time() - start_time}")
+        return fsm
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
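The core trick in both new processors is the same: ask the Outlines FSM which token ids are legal from the current state, then add a mask of -inf everywhere else so greedy or sampled decoding can only pick legal tokens. A minimal sketch of just that masking step, with a hypothetical allowed set standing in for fsm.allowed_token_ids(fsm_grammar_state):

import math
import torch

vocab_size = 32000
logits = torch.randn(vocab_size)      # one decoding step's scores
allowed_tokens = [42, 1337, 2]        # hypothetical FSM-approved token ids

mask = torch.full((vocab_size,), -math.inf)
mask[allowed_tokens] = 0              # only legal tokens keep their scores
biased_scores = logits + mask         # -inf elsewhere -> zero probability

assert biased_scores.argmax().item() in allowed_tokens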
server/text_generation_server/utils/tokens.py

@@ -8,6 +8,7 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
+    GrammarLogitProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -15,6 +16,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousGrammarLogitProcessor,
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -58,6 +60,11 @@ class NextTokenChooser:
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.grammar_processor = (
+            GrammarLogitProcessor(tokenizer, device)
+            if grammar and grammar != ""
+            else None
+        )
         self.tokenizer = tokenizer
 
         has_warpers = (
@@ -75,14 +82,9 @@ class NextTokenChooser:
 
         sampling = do_sample or has_warpers
 
+        self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
-        self.grammars = grammar
-
-        # TODO: is grammar a subset of sampling? If so, we should merge them
-        if grammar:
-            self.choice = Grammar(tokenizer, device, grammar)
-        else:
-            self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.grammar = grammar
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -91,15 +93,23 @@ class NextTokenChooser:
             scores = self.repetition_processor(input_ids, scores)
         if self.frequency_processor is not None:
             scores = self.frequency_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(
+                input_ids, scores, self.fsm_grammar_state, self.grammar
+            )
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
         else:
             scores, next_logprob = self.static_warper(scores)
 
-        next_id = self.choice(scores[-1], self.fsm_grammar_state, self.grammars).view(
-            1, 1
-        )
+        next_id = self.choice(scores[-1]).view(1, 1)
+        if self.grammar_processor is not None:
+            next_state = self.grammar_processor.advance(
+                next_id.item(), self.fsm_grammar_state, self.grammar
+            )
+            self.fsm_grammar_state = next_state
 
         return next_id, next_logprob
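This hunk is the heart of the "prefer grammar as logit processor" change: instead of swapping in a special Grammar choice object, the grammar now biases scores like the watermark/repetition/frequency processors do, and the FSM state advances only after a token has actually been chosen. A paraphrased sketch of the resulting flow (not verbatim TGI code):

def choose_next_token(chooser, input_ids, scores):
    # 1. the grammar masks scores, composing with the other processors/warpers
    if chooser.grammar_processor is not None:
        scores = chooser.grammar_processor(
            input_ids, scores, chooser.fsm_grammar_state, chooser.grammar
        )
    # 2. any choice strategy (greedy or sampling) picks from the masked scores
    next_id = chooser.choice(scores[-1]).view(1, 1)
    # 3. the FSM transitions on the token that was actually emitted
    if chooser.grammar_processor is not None:
        chooser.fsm_grammar_state = chooser.grammar_processor.advance(
            next_id.item(), chooser.fsm_grammar_state, chooser.grammar
        )
    return next_id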
@@ -229,7 +239,7 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
         tokenizer=None,
-        grammar=None,
+        grammars=None,
         fsm_grammar_states=None,
     ):
         warpers = []
@@ -262,6 +272,12 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        self.grammar_processor = (
+            HeterogeneousGrammarLogitProcessor(tokenizer, device)
+            if any([grammar != "" and grammar is not None for grammar in grammars])
+            else None
+        )
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -284,21 +300,18 @@ class HeterogeneousNextTokenChooser:
 
         self.warpers = warpers
 
-        if grammar is not None:
-            self.choice = Grammar(tokenizer, device)
-        elif any(do_sample):
+        if any(do_sample):
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             self.choice = Greedy()
 
-        self.use_grammar = grammar is not None
         self.seeds = seeds
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammar
+        self.grammars = grammars
 
     def __call__(
         self,
@@ -327,11 +340,13 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
 
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
-            _next_ids = self.choice(_scores, self.fsm_grammar_states, self.grammars)
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(
+                    input_ids, _scores, self.fsm_grammar_states, self.grammars
+                )
+            _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
         next_ids = next_ids.view(B * S)
@@ -386,6 +401,19 @@ class HeterogeneousNextTokenChooser:
         else:
             speculative_ids = None
 
+        # advance the grammar state
+        if self.grammar_processor is not None:
+            for i in range(len(self.fsm_grammar_states)):
+                try:
+                    self.fsm_grammar_states[i] = self.grammar_processor.advance(
+                        next_ids[i].item(), self.fsm_grammar_states[i], self.grammars[i]
+                    )
+                except:
+                    import ipdb
+
+                    ipdb.set_trace()
+                    pass
+
         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
 
     def filter(self, indices):
@@ -408,12 +436,16 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
-        if self.use_grammar or any(self.do_sample):
-            _, new_fsm_grammar_states, new_grammars = self.choice.filter(indices, self.fsm_grammar_states, self.grammars)
-            self.fsm_grammar_states = new_fsm_grammar_states
-            self.grammars = new_grammars
-        else:
-            self.choice = Greedy()
+        new_grammars = []
+        new_fsm_grammar_states = []
+        for i in indices:
+            new_grammars.append(self.grammars[i])
+            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+
+        self.grammars = new_grammars
+        self.fsm_grammar_states = new_fsm_grammar_states
+
+        self.choice = Greedy()
 
         return self
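One performance note on compile_fsm, which the processors call on every decoding step: it is wrapped in lru_cache, so the expensive RegexFSM build (the commit even prints its wall time) runs once per distinct grammar and subsequent calls are cache hits. In the commit the cache key also includes self and the tokenizer, since lru_cache is applied to a method. A toy sketch of the same pattern, using re.compile as a stand-in for the RegexFSM build:

from functools import lru_cache
import re

@lru_cache(maxsize=32, typed=True)
def compile_grammar(regex_string: str):
    # stand-in for the expensive RegexFSM(regex_string, tokenizer) construction
    return re.compile(regex_string)

fsm_a = compile_grammar(r"\d+")
fsm_b = compile_grammar(r"\d+")
assert fsm_a is fsm_b  # second call hits the cache; nothing is recompiled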
@@ -438,7 +470,7 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            grammar=[pb_.grammar for pb_ in pb],
+            grammars=[pb_.grammar for pb_ in pb],
             fsm_grammar_states=[pb_.fsm_grammar_state for pb_ in pb],
         )
 
@@ -449,7 +481,7 @@ class Sampling:
         self.generator.manual_seed(seed)
         self.seed = seed
 
-    def __call__(self, logits):
+    def __call__(self, logits, *args):
         probs = torch.nn.functional.softmax(logits, -1)
         # Avoid GPU<->CPU sync done by torch multinomial
         # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
@@ -464,97 +496,6 @@ class Greedy:
     def filter(self, indices, *args):
         return self
 
-
-class Grammar:
-    fsm_state: DefaultDict[int, int]
-    fsm: RegexFSM
-
-    def __init__(self, tokenizer, device):
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def __call__(self, logits, fsm_grammar_states, grammars):
-        empty = torch.ones(logits.shape[0], dtype=torch.int64, device=logits.device)
-        try:
-            for i in range(len(fsm_grammar_states)):
-                if fsm_grammar_states[i] == -1:
-                    continue
-
-                # if grammar is '' or None, return the greedy token
-                if grammars[i] == "" or grammars[i] is None:
-                    empty[i] = logits[i].argmax().item()
-                    continue
-
-                # this is cached and should be fast after the first time
-                fsm = self.compile_fsm(grammars[i], self.tokenizer)
-                allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-                mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
-                mask[allowed_tokens] = 0
-                biased_scores = logits[i : i + 1] + mask
-                greedy = biased_scores.argmax(dim=-1)
-
-                # if greedy is empty, return the eos token
-                if greedy.shape[0] == 0:
-                    continue
-
-                # import ipdb; ipdb.set_trace()
-                fsm_grammar_states[i] = fsm.next_state(
-                    fsm_grammar_states[i], greedy.item()
-                )
-
-                empty[i] = greedy.item()
-        except Exception as e:
-            print(f"Exception: {e}")
-            import ipdb
-
-            ipdb.set_trace()
-        return empty
-
-    @lru_cache(maxsize=32, typed=True)
-    def compile_fsm(self, schema, tokenizer):
-        start_time = time.time()
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        is_json_string = schema.startswith("{") and schema.endswith("}")
-        regex_string = build_regex_from_object(schema) if is_json_string else schema
-        fsm = RegexFSM(regex_string, tokenizer)
-        print(f"Compile FSM: {time.time() - start_time}")
-        return fsm
-
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to work with the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
-
-    def filter(self, indices, fsm_grammar_states, grammars):
-        new_fsm_grammar_states = []
-        new_grammars = []
-
-        for i in indices:
-            new_fsm_grammar_states.append(fsm_grammar_states[i])
-            new_grammars.append(grammars[i])
-
-        return self, new_fsm_grammar_states, new_grammars
-
-
 class HeterogeneousSampling:
     r"""
@@ -574,7 +515,7 @@ class HeterogeneousSampling:
 
         self.greedy = Greedy()
 
-    def __call__(self, logits, fsm_grammar_states, grammars):
+    def __call__(self, logits):
         out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
         if self.greedy_indices:
             # Computing for all indices is faster than slicing
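Finally, the signature changes above make the choice objects interchangeable again: HeterogeneousSampling drops the grammar arguments, and Sampling.__call__ gains a *args it ignores, presumably so call sites that still pass grammar state keep working. A simplified sketch of that pattern (my own illustration, not the commit's Sampling implementation, which avoids torch.multinomial for performance):

import torch

class Sampling:
    """Sketch: *args lets a choice object ignore grammar state it no longer needs."""

    def __init__(self, seed: int = 0):
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def __call__(self, logits, *args):  # extra positional args are ignored
        probs = torch.nn.functional.softmax(logits, -1)
        return torch.multinomial(probs, num_samples=1, generator=self.generator)

choice = Sampling(seed=42)
logits = torch.tensor([0.1, 2.0, -1.0])
print(choice(logits))                # new-style call
print(choice(logits, 0, "grammar"))  # legacy call shape still works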