From b013cb4f4a9beb64e64f0809d095973549f7d951 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 1 Feb 2024 23:13:34 +0000
Subject: [PATCH] feat: first draft constraining generation via outlines

---
 .../models/causal_lm.py                       |  2 +-
 server/text_generation_server/utils/tokens.py | 75 ++++++++++++++++++-
 2 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index a7a16212..f692a7ec 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index d6ca10c7..3d160de7 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,5 +1,7 @@
 import re
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, DefaultDict
+from collections import defaultdict
+import math
 
 import torch
 from text_generation_server.pb import generate_pb2
@@ -18,6 +20,7 @@ from text_generation_server.utils.logits_process import (
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
+from outlines.fsm.fsm import RegexFSM
 
 class NextTokenChooser:
     def __init__(
@@ -32,6 +35,7 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        tokenizer=None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -46,6 +50,7 @@
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.tokenizer = tokenizer
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -61,7 +66,9 @@
             self.static_warper = None
 
         sampling = do_sample or has_warpers
+        # TODO: make grammar mode toggleable per request instead of always on
-        self.choice = Sampling(seed, device) if sampling else Greedy()
+        # self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.choice = Grammar(tokenizer, device)
 
     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -85,6 +92,7 @@
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -97,6 +105,7 @@
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            tokenizer=tokenizer,
         )
 
 
@@ -419,7 +428,69 @@ class Greedy:
     def __call__(self, logits):
         return logits.argmax(dim=-1)
 
+
+# TODO: move this whole thing into the logit_process util and make it a Sampler
+class Grammar:
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device):
+        # TODO: take the regex from the request instead of hardcoding it;
+        # this one matches an IPv4 address
+        regex_str = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+
+        # TODO: adapting the tokenizer is expensive; we should do it only once.
+        # This is a temporary solution.
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        self.fsm = RegexFSM(regex_str, tokenizer)
+        self.fsm_state = defaultdict(int)
+        # keep the EOS id from the tokenizer so it can be emitted once the
+        # FSM reaches its final state
+        self.eos_token_id = tokenizer.eos_token_id
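+
+    # Note: fsm_state maps a sequence id to an integer FSM state. RegexFSM
+    # starts at state 0 (hence the defaultdict(int)) and reports -1 once the
+    # regex has been fully matched; __call__ checks for that below.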
+
+    def __call__(self, logits):
+        # TODO: handle seq_id properly
+        seq_id = 0
+
+        if self.fsm_state[seq_id] == -1:
+            # the FSM is finished: force the EOS token
+            return torch.tensor(self.eos_token_id, device=logits.device)
+
+        # mask out every token the FSM does not allow from the current state
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        mask = torch.full((logits.shape[-1],), -math.inf, device=logits.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+
+        # greedily pick the allowed token with the highest score
+        greedy = biased_scores.argmax(dim=-1)
+
+        # advance the FSM with the chosen token
+        self.fsm_state[seq_id] = self.fsm.next_state(
+            self.fsm_state[seq_id], greedy.item()
+        )
+
+        return greedy
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt a `transformers` tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer output to be able to compile FSMs for this model.
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
 
 class HeterogeneousSampling:
     r"""
     Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
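
The loop below is a standalone sketch of what the new Grammar choice does,
built only from the RegexFSM calls this patch already exercises
(allowed_token_ids and next_state) and the same tokenizer adaptation. The
model name ("gpt2"), the prompt, and the hex-color regex are illustrative
placeholders, and convert_token_to_string is simplified for a non-Llama
tokenizer; treat it as a sketch, not as part of the change.

    import math

    import torch
    from outlines.fsm.fsm import RegexFSM
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # same adaptation as Grammar.adapt_tokenizer: outlines expects
    # .vocabulary, .special_tokens and .convert_token_to_string
    tokenizer.vocabulary = tokenizer.get_vocab()
    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
    tokenizer.convert_token_to_string = lambda t: tokenizer.convert_tokens_to_string([t])

    # placeholder pattern: a hex color such as "#1a2b3c"
    fsm = RegexFSM(r"#[0-9a-fA-F]{6}", tokenizer)

    input_ids = tokenizer("The site color is ", return_tensors="pt").input_ids
    state = 0  # RegexFSM starts in state 0; -1 means the match is complete
    while state != -1:
        logits = model(input_ids).logits[0, -1]
        # bias exactly as in the patch: -inf everywhere except allowed tokens
        mask = torch.full((logits.shape[-1],), -math.inf)
        mask[fsm.allowed_token_ids(state)] = 0
        next_id = (logits + mask).argmax(dim=-1)
        state = fsm.next_state(state, next_id.item())
        input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=-1)

    print(tokenizer.decode(input_ids[0]))  # prompt followed by a valid hex color

Biasing with an additive -inf mask rather than filtering the vocabulary keeps
the logits tensor shape unchanged, which is why the same trick drops into
NextTokenChooser without touching the surrounding sampling code.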