diff --git a/README.md b/README.md index 22670987..706b48d3 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,89 @@ This is a personal copy of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes that are not meant for any serious use. +Specifically it adds: +- context free grammar +- logit bias +- negative prompts + +All 3 can be combined of course. + +### Context free grammar + +This adds EBNF grammar as implemented in [Saibo-creator/transformers-CFG](https://github.com/Saibo-creator/transformers-CFG.git). See that repo for grammar examples etc. + +This feature adds extra parameters: +- boolean - use_grammar_constraint +- string - grammar + +For example one can send the folowing request to generate and receive valid json(with llama2 instruct): +``` +{ + "inputs": "This is a valid json string for http request:", + "parameters": { + "truncate": 1000, + "max_new_tokens": 300, + "use_grammar_constraint": true, + "grammar": "root ::= object\n\nobject ::= \"{\" ws ( string \":\" ws value (\",\" ws string \":\" ws value)* )? \"}\"\n\nvalue ::= object | array | string | number | (\"true\" | \"false\" | \"null\") ws\n\narray ::= \"[\" ws ( value (\",\" ws value)* )? \"]\" ws\n\nstring ::= \"\\\"\" ( [a-zA-Z0-9] )* \"\\\"\" ws\n\nnumber ::= (\"-\"? ([0-9] | [1-9] [0-9]*)) (\".\" [0-9]+)? ([eE] [-+]? [0-9]+)? ws\n\n\nws ::= ([ \\t\\n] ws)?\n" + } +} +``` + +### Negative prompts + +This is a rather expensive and crude implementation that has to run the model to obtain unconditional token biases during processing so it does slow down generation when this feature is used. + +Please see documentation for Huggingface's UnbatchedClassifierFreeGuidanceLogitsProcessor for more info. In summary, guidance scale of 1 does nothing. Anything above 1 increases the strength of the negative prompt, anything below 1 turns it into positive. + +This adds the folowing parameters: +- string - negative_inputs +- float - guidance_scale + +It can be used with for example(with gpt2): +``` +{ + "inputs": "Today, a dragon flew over Paris, France,", + "parameters": { + "truncate": 1000, + "negative_inputs": "A very happy event happened,", + "max_new_tokens": 300, + "guidance_scale": 1.5 + } +} +``` + +### Logit bias + +One can use this for disallowing completely, or adjusting a probability of certain words in the output. + +Additional parameters are: +- list of tuples, each containing a string and a flot number + +the string is a word to bias and the number is the float to make it more of less likely. Some models trained with spaces before tokens like gpt2 require an extra space to be included before each word, others like llama2 don't. + +An example use is(with llama2 instruct): +``` +{ + "inputs": "[INST] <>\nThese are a series of question/answer pairs.\n<>[INST]in zfs, what is the console command to create a snapshot?[/INST]", + "parameters": { + "truncate": 1000, + "max_new_tokens": 300, + "logit_bias":[ + [ + "console", + -10 + ], + [ + "You", + -2 + ] + ] + + } +} +``` + +
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 68563cb5..3a3e6331 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,5 +1,5 @@ import re -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Dict import torch from text_generation_server.pb import generate_pb2 @@ -14,7 +14,7 @@ from text_generation_server.utils.logits_process import ( static_warper, ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor -from transformers import PreTrainedTokenizerBase,RepetitionPenaltyLogitsProcessor,UnbatchedClassifierFreeGuidanceLogitsProcessor,PreTrainedModel +from transformers import PreTrainedTokenizerBase,RepetitionPenaltyLogitsProcessor,UnbatchedClassifierFreeGuidanceLogitsProcessor,PreTrainedModel, SequenceBiasLogitsProcessor from transformers_cfg.grammar_utils import IncrementalGrammarConstraint from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor @@ -37,6 +37,7 @@ class NextTokenChooser: grammar="", guidance_scale=1.0, negative_inputs="", + logit_bias=None, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -47,6 +48,16 @@ class NextTokenChooser: else None ) + self.sequence_bias_processors = None + if logit_bias is not None and len(logit_bias) > 0 and tokenizer is not None: + bias_sequence = [] + for lb in logit_bias: + bias_sequence.append({tuple(tokenizer([lb.word], add_special_tokens=False).input_ids[0]): float(lb.bias)}) + if len(bias_sequence) > 0: + self.sequence_bias_processors = [] + for bs in bias_sequence: + self.sequence_bias_processors.append(SequenceBiasLogitsProcessor(sequence_bias=bs)) + if use_grammar_constraint: grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer) self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar) @@ -85,6 +96,11 @@ class NextTokenChooser: def __call__(self, input_ids, scores): if self.guidance_scale_processor is not None: scores = self.guidance_scale_processor(input_ids, scores) + if self.sequence_bias_processors is not None and len(self.sequence_bias_processors) > 0: + with open('/tmp/output.txt', 'a') as file: + print("We have:"+str(len(self.sequence_bias_processors))+" sequence bias processors", file=file) + for sbp in self.sequence_bias_processors: + scores = sbp(input_ids, scores) if self.grammar_processor is not None: scores = self.grammar_processor(input_ids, scores) if self.watermark_processor is not None: @@ -126,6 +142,7 @@ class NextTokenChooser: model=model, guidance_scale=pb.guidance_scale, negative_inputs=pb.negative_inputs, + logit_bias=pb.logit_bias, ) @@ -229,6 +246,7 @@ class HeterogeneousNextTokenChooser: grammar: List[str], guidance_scale: List[float], negative_inputs: List[str], + logit_bias: List[List[Dict]], temperature: List[float], repetition_penalty: List[float], top_k: List[int], @@ -259,6 +277,22 @@ class HeterogeneousNextTokenChooser: else None ) + self.sequence_bias_processors = None + if any(lst for lst in logit_bias) and tokenizer is not None: + self.sequence_bias_processors = [] + for logit_bias_stage in logit_bias: + sequence_bias_ps = None + if logit_bias_stage is not None and len(logit_bias_stage) > 0: + bias_sequence = [] + for lb in logit_bias: + bias_sequence.append({tuple(tokenizer([lb.word], add_special_tokens=False).input_ids[0]): float(lb.bias)}) + if len(bias_sequence) > 0: + sequence_bias_ps = [] + for bs in bias_sequence: + sequence_bias_ps.append(SequenceBiasLogitsProcessor(sequence_bias=bs)) + if sequence_bias_ps is not None: + self.sequence_bias_processors.append(HeterogeneousProcessorWrapper(sequence_bias_ps)) + if any(use_grammar_constraint): grammar_processors = { i: GrammarConstrainedLogitsProcessor(IncrementalGrammarConstraint(grammar[i], "root", tokenizer)) @@ -323,6 +357,9 @@ class HeterogeneousNextTokenChooser: _scores = scores[:, j] if self.watermark_processor is not None: _scores = self.watermark_processor(input_ids, _scores) + if self.sequence_bias_processors is not None and len(self.sequence_bias_processors) > 0: + for sbp in self.sequence_bias_processors: + _scores = sbp(input_ids, _scores) if self.repetition_processor is not None: _scores = self.repetition_processor(input_ids, _scores) if self.grammar_processor is not None: @@ -437,6 +474,7 @@ class HeterogeneousNextTokenChooser: grammar=[pb_.grammar for pb_ in pb], guidance_scale=[pb_.guidance_scale for pb_ in pb], negative_inputs=[pb_.negative_inputs for pb_ in pb], + logit_bias=[pb_.logit_bias for pb_ in pb] )