From e749d0cc5a96841175b420594c3891db351a4366 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=81ukasz=20Olszewski?=
Date: Wed, 20 Dec 2023 12:42:56 +0100
Subject: [PATCH] Adding CFG (context free grammar) to TGI

---
 README.md                                      |  4 +++
 proto/generate.proto                           |  2 ++
 router/client/src/client.rs                    |  2 ++
 router/src/health.rs                           |  2 ++
 router/src/lib.rs                              |  8 +++++
 router/src/validation.rs                       |  4 +++
 server/requirements_common.txt                 |  1 +
 .../models/causal_lm.py                        |  4 +--
 .../models/flash_causal_lm.py                  |  5 ++--
 .../models/flash_mistral.py                    |  2 +-
 .../models/galactica.py                        |  2 +-
 .../models/idefics_causal_lm.py                |  2 +-
 .../models/seq2seq_lm.py                       |  2 +-
 server/text_generation_server/server.py        |  2 +-
 server/text_generation_server/utils/tokens.py  | 29 +++++++++++++++++++
 15 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index d99b7306..22670987 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,7 @@
+## Warning
+
+This is a personal copy of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes that are not meant for any serious use.
+
diff --git a/proto/generate.proto b/proto/generate.proto
index 1f30df38..13cac7a3 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -68,6 +68,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    bool use_grammar_constraint = 9;
+    string grammar = 10;
 }

 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 4723d664..22cf6c11 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -126,6 +126,8 @@ impl Client {
                     seed: 0,
                     repetition_penalty: 1.2,
                     watermark: true,
+                    use_grammar_constraint: false,
+                    grammar: "".to_string(),
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
diff --git a/router/src/health.rs b/router/src/health.rs
index ab290fc1..78c94017 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -44,6 +44,8 @@ impl Health {
                 seed: 0,
                 repetition_penalty: 1.0,
                 watermark: false,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 898fcd04..8e19605c 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -123,6 +123,12 @@ pub(crate) struct GenerateParameters {
     pub watermark: bool,
     #[serde(default)]
+    #[schema(default = "false")]
+    pub use_grammar_constraint: bool,
+    #[serde(default)]
+    #[schema(default = "")]
+    pub grammar: String,
+    #[serde(default)]
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
     #[schema(default = "true")]
@@ -158,6 +164,8 @@ fn default_parameters() -> GenerateParameters {
         stop: Vec::new(),
         truncate: None,
         watermark: false,
+        use_grammar_constraint: false,
+        grammar: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 64f25c82..f5a28b94 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -163,6 +163,8 @@ impl Validation {
             truncate,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
             decoder_input_details,
             top_n_tokens,
             ..
@@ -279,6 +281,8 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
diff --git a/server/requirements_common.txt b/server/requirements_common.txt
index 5a321834..ec17498e 100644
--- a/server/requirements_common.txt
+++ b/server/requirements_common.txt
@@ -39,6 +39,7 @@ setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+git+https://github.com/epfl-dlab/transformers_cfg.git@4d764337d4d0c167a32560e601fc7a5fb9f33e85
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7b10256c..97b04b60 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -258,7 +258,7 @@ class CausalLMBatch(Batch):

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
+    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
         max_input_length = 0
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 930082cd..cdf3803b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -233,7 +233,7 @@ class FlashCausalLMBatch(Batch):
         )

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -468,7 +468,7 @@ class FlashCausalLMBatch(Batch):

     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -589,6 +589,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            tokenizer=tokenizer,
         )

         speculative_ids = (
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8c6cb025..18b98c66 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -190,7 +190,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 42ff1c80..a2c30255 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 2f28688d..1f633f8a 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f2e4cec6..d506a46a 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index d5adbd32..847ec583 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -143,7 +143,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 04cc8d97..998e1df3 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -15,6 +15,8 @@ from text_generation_server.utils.logits_process import (
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor


 class NextTokenChooser:
@@ -29,6 +31,9 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
device="cpu", + tokenizer=None, + use_grammar_constraint=False, + grammar="", ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -39,6 +44,12 @@ class NextTokenChooser: else None ) + if use_grammar_constraint: + grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer) + self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + else: + self.grammar_processor = None + has_warpers = ( (temperature is not None and temperature != 1.0) or (top_k is not None and top_k != 0) @@ -60,6 +71,8 @@ class NextTokenChooser: scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) + if self.grammar_processor is not None: + scores = self.grammar_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -75,6 +88,7 @@ class NextTokenChooser: cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ -86,6 +100,9 @@ class NextTokenChooser: do_sample=pb.do_sample, seed=pb.seed, device=device, + use_grammar_constraint=pb.use_grammar_constraint, + grammar=pb.grammar, + tokenizer=tokenizer, ) @@ -181,7 +198,10 @@ class HeterogeneousNextTokenChooser: self, dtype: torch.dtype, device: torch.device, + tokenizer: PreTrainedTokenizerBase, watermark: List[bool], + use_grammar_constraint: List[bool], + grammar: List[str], temperature: List[float], repetition_penalty: List[float], top_k: List[int], @@ -212,6 +232,11 @@ class HeterogeneousNextTokenChooser: else None ) + if use_grammar_constraint: + grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + warpers.append(grammar_processor) + if any([x != 1.0 for x in temperature]): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -359,6 +384,7 @@ class HeterogeneousNextTokenChooser: pb: List[generate_pb2.NextTokenChooserParameters], dtype: torch.dtype, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], @@ -371,6 +397,9 @@ class HeterogeneousNextTokenChooser: seeds=[pb_.seed for pb_ in pb], device=device, dtype=dtype, + tokenizer=tokenizer, + use_grammar_constraint=use_grammar_constraint, + grammar=grammar, )