feat: first draft constraining generation via outlines

drbh 2024-02-01 23:13:34 +00:00
parent 4c2848b24b
commit b013cb4f4a
2 changed files with 74 additions and 3 deletions
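At a high level, this draft constrains decoding with a token-level finite-state machine that outlines compiles from a regex: at each step the FSM reports which token ids can legally extend the match, every other logit is masked to -inf, and the argmax is taken over what remains. A minimal standalone sketch of that loop, using the same outlines calls as the diff below (`model_forward` and `adapted_tokenizer` are hypothetical stand-ins, not part of this commit):

    import math
    import torch
    from outlines.fsm.fsm import RegexFSM

    fsm = RegexFSM(r"\d+", adapted_tokenizer)  # adapted_tokenizer: see adapt_tokenizer below
    state = 0
    while state != -1:  # this draft treats -1 as the FSM's final state
        logits = model_forward(input_ids)               # hypothetical forward pass
        mask = torch.full((logits.shape[-1],), -math.inf)
        mask[fsm.allowed_token_ids(state)] = 0          # zero bias for legal tokens only
        next_id = (logits + mask).argmax(dim=-1)
        # append next_id to input_ids here (omitted in this sketch)
        state = fsm.next_state(state, next_id.item())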

server/text_generation_server/models/causal_lm.py

@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
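Call-site note: from_pb now takes the tokenizer as a third argument, so every batch type that builds a NextTokenChooser needs the same one-line change. A hedged sketch of such a sibling call site, with names assumed rather than taken from this diff:

    # hypothetical call site in another batch type's from_pb
    next_token_choosers.append(
        NextTokenChooser.from_pb(r.parameters, device, tokenizer)
    )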

server/text_generation_server/utils/tokens.py

@@ -1,5 +1,7 @@
 import re
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, DefaultDict
+from collections import defaultdict
+import math
 
 import torch
 from text_generation_server.pb import generate_pb2
@@ -18,6 +20,7 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from outlines.fsm.fsm import RegexFSM
 
 
 class NextTokenChooser:
     def __init__(
@@ -32,6 +35,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        tokenizer=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -46,6 +50,7 @@ class NextTokenChooser:
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.tokenizer = tokenizer
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -61,7 +66,9 @@ class NextTokenChooser:
             self.static_warper = None
 
         sampling = do_sample or has_warpers
-        self.choice = Sampling(seed, device) if sampling else Greedy()
+        # TODO: toggle grammar
+        # self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.choice = Grammar(tokenizer, device)
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -85,6 +92,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -97,6 +105,7 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            tokenizer=tokenizer,
         )
@@ -419,6 +428,68 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
 
+# TODO: move this whole thing into the logit_process util and make it a Sampler
+class Grammar:
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        # TODO: take the regex from the request instead of hardcoding it.
+        # This one matches a dotted-quad IPv4 address.
+        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+        # TODO: adapting the tokenizer is expensive, so it should be done only
+        # once per model; this is a temporary solution
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        self.tokenizer = tokenizer
+        self.fsm = RegexFSM(regex_str, tokenizer)
+        self.fsm_state = defaultdict(int)
+
+    def __call__(self, logits):
+        # TODO: handle seq_id properly; assume a single sequence for now
+        seq_id = 0
+
+        # -1 means the FSM reached its final state: force EOS
+        if self.fsm_state[seq_id] == -1:
+            return torch.tensor([self.tokenizer.eos_token_id], device=logits.device)
+
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+
+        # greedily pick the allowed token with the highest score
+        greedy = biased_scores.argmax(dim=-1)
+
+        # now update the fsm state
+        self.fsm_state[seq_id] = self.fsm.next_state(
+            self.fsm_state[seq_id], greedy.item()
+        )
+
+        return greedy
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt a `transformers` tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+
 class HeterogeneousSampling:
     r"""