mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

commit b013cb4f4a (parent 4c2848b24b)
feat: first draft constraining generation via outlines
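Summary of the change: the batch code now passes the tokenizer into NextTokenChooser, whose token choice is (for now unconditionally) delegated to a new Grammar sampler. Grammar compiles a regular expression into an outlines RegexFSM over the vocabulary and, at each step, masks the logits so that only tokens keeping the output inside the regex can be emitted.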
server/text_generation_server/models/causal_lm.py

@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
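Threading the tokenizer through from_pb is what lets the chooser build the token-level FSM below; previously the tokenizer was only needed here for the stopping criteria.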
server/text_generation_server/utils/tokens.py

@@ -1,5 +1,7 @@
 import re
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, DefaultDict
+from collections import defaultdict
+import math
 
 import torch
 from text_generation_server.pb import generate_pb2
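The new imports serve the Grammar class added at the bottom of this file: `math` for the `-math.inf` logit mask, `defaultdict` for per-sequence FSM state, and `DefaultDict` for its type annotation (`Callable` is imported but not used in any hunk shown here).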
@@ -18,6 +20,7 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
+from outlines.fsm.fsm import RegexFSM
 
 class NextTokenChooser:
     def __init__(
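`RegexFSM` is outlines' compiled token-level automaton for a regular expression; the draft relies on two of its methods, `allowed_token_ids(state)` and `next_state(state, token_id)`, and the import makes `outlines` a new runtime dependency of the Python server.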
@@ -32,6 +35,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        tokenizer=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -46,6 +50,7 @@ class NextTokenChooser:
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.tokenizer = tokenizer
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -61,7 +66,9 @@ class NextTokenChooser:
             self.static_warper = None
 
         sampling = do_sample or has_warpers
-        self.choice = Sampling(seed, device) if sampling else Greedy()
+        # TODO toggle grammar
+        # self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.choice = Grammar(tokenizer, device)
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
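As the `TODO toggle grammar` comment says, this draft routes every request through Grammar and drops sampling entirely. A minimal sketch of the intended toggle, assuming a hypothetical `grammar_regex` constructor argument and a Grammar that accepts the regex (neither is in this commit):

    # Hypothetical toggle -- `grammar_regex` and the third Grammar argument
    # are assumptions, not part of this commit.
    if grammar_regex is not None:
        self.choice = Grammar(tokenizer, device, grammar_regex)
    else:
        self.choice = Sampling(seed, device) if sampling else Greedy()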
@@ -85,6 +92,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -97,6 +105,7 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            tokenizer=tokenizer,
         )
 
 
@@ -419,6 +428,68 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
+# TODO: move this whole thing into the logit_process util and make it a Sampler
+class Grammar:
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        # TODO: get regex on init not hardcoded
+        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+
+        # TODO: adapt tokenizer is expensive, we should do it only once
+        # this is a temporary solution
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_str, tokenizer)
+        self.fsm = fsm
+        self.fsm_state = defaultdict(int)
+
+    def __call__(self, logits):
+        # TODO: handle seq_id properly
+        seq_id = 0
+
+        if self.fsm_state[seq_id] == -1:
+            return self.fsm_state[seq_id].eos_token_id
+
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+
+        # greedly pick the token with the highest score
+        greedy = biased_scores.argmax(dim=-1)
+
+        # now update the fsm state
+        self.fsm_state[seq_id] = self.fsm.next_state(
+            self.fsm_state[seq_id], greedy.item()
+        )
+        return greedy
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
 
 class HeterogeneousSampling:
     r"""
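How the new sampler works: `fsm_state` starts every sequence at state 0 (`defaultdict(int)`), `allowed_token_ids` returns the token ids that are legal continuations from that state, and the additive mask sets every other id to `-inf` before the argmax, so the greedy pick can only be a legal token; `next_state` then advances the automaton, and `-1` marks a terminal state. The hardcoded `regex_str` matches dotted-quad IPv4 addresses (each octet 0-255). For intuition, the masking step in isolation (toy numbers, no real tokenizer or FSM):

    import math
    import torch

    # Toy vocabulary of 5 tokens; suppose the FSM allows only ids 1 and 3.
    logits = torch.tensor([[2.0, 0.5, 3.0, 0.1, 1.0]])
    allowed_tokens = [1, 3]

    mask = torch.full((logits.shape[-1],), -math.inf)
    mask[allowed_tokens] = 0
    biased_scores = logits + mask  # disallowed ids become -inf

    print(biased_scores.argmax(dim=-1))  # tensor([1]): id 2 had the best raw score but is banned

Beyond the TODOs the draft already flags (`seq_id` pinned to 0, so one FSM state is shared across the whole batch, and the tokenizer being re-adapted per chooser), the terminal-state branch is buggy: `self.fsm_state[seq_id]` is a plain `int`, so `.eos_token_id` would raise `AttributeError`. It presumably means the tokenizer's EOS id; a sketch of the fix, assuming the EOS id is captured in `__init__` (not done in this commit):

    # In __init__ (hypothetical): self.eos_token_id = tokenizer.eos_token_id

    if self.fsm_state[seq_id] == -1:
        # FSM finished matching the regex: force EOS instead of crashing.
        return torch.full(
            (logits.shape[0],), self.eos_token_id, dtype=torch.long, device=logits.device
        )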