mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: support other models and add fsm caching
This commit is contained in:
parent
56e919e459
commit
8fd2664a3c
@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device, tokenizer
|
||||||
)
|
)
|
||||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
@ -593,6 +593,8 @@ class FlashCausalLMBatch(Batch):
|
|||||||
next_token_chooser_parameters,
|
next_token_chooser_parameters,
|
||||||
dtype=batches[0].next_token_chooser.dtype,
|
dtype=batches[0].next_token_chooser.dtype,
|
||||||
device=batches[0].next_token_chooser.device,
|
device=batches[0].next_token_chooser.device,
|
||||||
|
tokenizer=batches[0].next_token_chooser.tokenizer,
|
||||||
|
grammar=batches[0].requests.parameters.grammar,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculative_ids = (
|
speculative_ids = (
|
||||||
|
@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device, tokenizer
|
||||||
)
|
)
|
||||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ class MambaBatch(Batch):
|
|||||||
for i, r in enumerate(pb.requests):
|
for i, r in enumerate(pb.requests):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
@ -22,6 +22,7 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess
|
|||||||
|
|
||||||
from outlines.fsm.fsm import RegexFSM
|
from outlines.fsm.fsm import RegexFSM
|
||||||
from outlines.fsm.json_schema import build_regex_from_object
|
from outlines.fsm.json_schema import build_regex_from_object
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
# TODO: remove when done debugging
|
# TODO: remove when done debugging
|
||||||
import time
|
import time
|
||||||
@ -219,6 +220,8 @@ class HeterogeneousNextTokenChooser:
|
|||||||
typical_p: List[float],
|
typical_p: List[float],
|
||||||
do_sample: List[bool],
|
do_sample: List[bool],
|
||||||
seeds: List[int],
|
seeds: List[int],
|
||||||
|
tokenizer=None,
|
||||||
|
grammar=None,
|
||||||
):
|
):
|
||||||
warpers = []
|
warpers = []
|
||||||
|
|
||||||
@ -272,11 +275,15 @@ class HeterogeneousNextTokenChooser:
|
|||||||
|
|
||||||
self.warpers = warpers
|
self.warpers = warpers
|
||||||
|
|
||||||
if any(do_sample):
|
first_grammar = grammar[0] if grammar else None
|
||||||
|
if first_grammar:
|
||||||
|
self.choice = Grammar(tokenizer, device, first_grammar)
|
||||||
|
elif any(do_sample):
|
||||||
self.choice = HeterogeneousSampling(do_sample, seeds, device)
|
self.choice = HeterogeneousSampling(do_sample, seeds, device)
|
||||||
else:
|
else:
|
||||||
self.choice = Greedy()
|
self.choice = Greedy()
|
||||||
|
|
||||||
|
self.use_grammar = grammar is not None
|
||||||
self.seeds = seeds
|
self.seeds = seeds
|
||||||
self.do_sample = do_sample
|
self.do_sample = do_sample
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@ -390,7 +397,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
self.seeds = [self.seeds[i] for i in indices]
|
self.seeds = [self.seeds[i] for i in indices]
|
||||||
self.do_sample = [self.do_sample[i] for i in indices]
|
self.do_sample = [self.do_sample[i] for i in indices]
|
||||||
|
|
||||||
if any(self.do_sample):
|
if self.use_grammar or any(self.do_sample):
|
||||||
self.choice.filter(indices)
|
self.choice.filter(indices)
|
||||||
else:
|
else:
|
||||||
self.choice = Greedy()
|
self.choice = Greedy()
|
||||||
@ -403,6 +410,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> "HeterogeneousNextTokenChooser":
|
) -> "HeterogeneousNextTokenChooser":
|
||||||
return HeterogeneousNextTokenChooser(
|
return HeterogeneousNextTokenChooser(
|
||||||
watermark=[pb_.watermark for pb_ in pb],
|
watermark=[pb_.watermark for pb_ in pb],
|
||||||
@ -416,6 +424,8 @@ class HeterogeneousNextTokenChooser:
|
|||||||
seeds=[pb_.seed for pb_ in pb],
|
seeds=[pb_.seed for pb_ in pb],
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
grammar=[pb_.grammar for pb_ in pb],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -443,22 +453,12 @@ class Grammar:
|
|||||||
fsm: RegexFSM
|
fsm: RegexFSM
|
||||||
|
|
||||||
def __init__(self, tokenizer, device, grammar):
|
def __init__(self, tokenizer, device, grammar):
|
||||||
# TODO: remove debug logs
|
fsm = self.compile_fsm(grammar, tokenizer)
|
||||||
start_time = time.time()
|
|
||||||
tokenizer = self.adapt_tokenizer(tokenizer)
|
|
||||||
print(f"Adapt tokenizer: {time.time() - start_time}")
|
|
||||||
start_time = time.time()
|
|
||||||
regex_string = build_regex_from_object(grammar)
|
|
||||||
print(f"Build regex: {time.time() - start_time}")
|
|
||||||
fsm = RegexFSM(regex_string, tokenizer)
|
|
||||||
print(f"Compile FSM: {time.time() - start_time}")
|
|
||||||
|
|
||||||
self.fsm = fsm
|
self.fsm = fsm
|
||||||
self.fsm_state = defaultdict(int)
|
self.fsm_state = defaultdict(int)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def __call__(self, logits):
|
def __call__(self, logits):
|
||||||
# TODO: handle seq_id properly
|
|
||||||
seq_id = 0
|
seq_id = 0
|
||||||
|
|
||||||
if self.fsm_state[seq_id] == -1:
|
if self.fsm_state[seq_id] == -1:
|
||||||
@ -477,6 +477,17 @@ class Grammar:
|
|||||||
self.fsm_state[seq_id], greedy.item()
|
self.fsm_state[seq_id], greedy.item()
|
||||||
)
|
)
|
||||||
return greedy
|
return greedy
|
||||||
|
|
||||||
|
def compile_fsm(self, schema, tokenizer):
    """Return the RegexFSM for ``schema``, compiling it only on a cache miss.

    The previous ``@lru_cache`` decorator sat directly on this instance
    method, so ``self`` was part of the cache key.  ``__init__`` calls this
    method on a freshly created object every time, which means the cache
    could not be shared between instances (and it kept those instances
    alive).  The cache is now keyed on ``(type(schema), schema, tokenizer)``
    only and stored on the owning class, so identical grammars compiled for
    the same tokenizer reuse one FSM.

    Args:
        schema: Grammar as a string.  A string that starts with ``{`` and
            ends with ``}`` is treated as a JSON schema and converted with
            ``build_regex_from_object``; anything else is used as a regex
            as-is.
        tokenizer: Tokenizer handed to ``self.adapt_tokenizer`` before the
            FSM is built.  It must be hashable, as was already required by
            the former ``lru_cache`` decorator.

    Returns:
        The compiled (or cached) ``RegexFSM``.
    """
    # Function-scope import so the block does not depend on a new
    # file-level import.
    from collections import OrderedDict

    # Same bound the former ``lru_cache(maxsize=32)`` used.
    max_entries = 32

    # Look only at the owning class' own namespace so a subclass gets its
    # own cache instead of silently sharing the parent's.
    owner = type(self)
    cache = owner.__dict__.get("_compiled_fsm_cache")
    if cache is None:
        cache = OrderedDict()
        owner._compiled_fsm_cache = cache

    # ``type(schema)`` keeps the ``typed=True`` behaviour of the old decorator.
    key = (type(schema), schema, tokenizer)
    if key in cache:
        # Mark as most recently used so eviction stays LRU.
        cache.move_to_end(key)
        return cache[key]

    # NOTE(review): timing/print is debug output inherited from the original
    # block ("remove when done debugging"); kept so runtime output is unchanged.
    start_time = time.time()
    adapted_tokenizer = self.adapt_tokenizer(tokenizer)
    is_json_string = schema.startswith("{") and schema.endswith("}")
    regex_string = build_regex_from_object(schema) if is_json_string else schema
    fsm = RegexFSM(regex_string, adapted_tokenizer)
    print(f"Compile FSM: {time.time() - start_time}")

    cache[key] = fsm
    if len(cache) > max_entries:
        # Drop the least recently used entry.
        cache.popitem(last=False)
    return fsm
|
||||||
|
|
||||||
|
|
||||||
def adapt_tokenizer(self, tokenizer):
|
def adapt_tokenizer(self, tokenizer):
|
||||||
"""Adapt tokenizer to work with the FSM.
|
"""Adapt tokenizer to work with the FSM.
|
||||||
|
Loading…
Reference in New Issue
Block a user