Adding CFG (context-free grammar) to TGI

Łukasz Olszewski 2023-12-20 12:42:56 +01:00
parent d077150eb7
commit e749d0cc5a
15 changed files with 62 additions and 9 deletions

View File

@@ -1,3 +1,7 @@
+## Warning
+This is a personal copy of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes that are not meant for any serious use.
+
 <div align="center">
   <a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">

View File

@@ -68,6 +68,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    bool use_grammar_constraint = 9;
+    string grammar = 10;
 }

 message StoppingCriteriaParameters {
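For orientation, the two new proto fields sit next to the existing sampling parameters. A minimal sketch (not part of this commit) of how a client could populate them through the Python bindings, assuming `generate_pb2` has been regenerated from the updated `generate.proto`; the grammar string is a made-up placeholder:

```python
# Sketch only: exercises the new fields 9 and 10 on NextTokenChooserParameters.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    do_sample=False,
    seed=0,
    repetition_penalty=1.0,
    watermark=False,
    use_grammar_constraint=True,        # new field 9
    grammar='root ::= "yes" | "no"',    # new field 10: EBNF grammar text (illustrative)
)
print(params)
```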

View File

@@ -126,6 +126,8 @@ impl Client {
                 seed: 0,
                 repetition_penalty: 1.2,
                 watermark: true,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,

View File

@@ -44,6 +44,8 @@ impl Health {
                 seed: 0,
                 repetition_penalty: 1.0,
                 watermark: false,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,

View File

@@ -123,6 +123,12 @@ pub(crate) struct GenerateParameters {
     pub watermark: bool,
     #[serde(default)]
     #[schema(default = "true")]
+    pub use_grammar_constraint: bool,
+    #[serde(default)]
+    #[schema(default = "false")]
+    pub grammar: String,
+    #[serde(default)]
+    #[schema(default = "")]
     pub details: bool,
     #[serde(default)]
     #[schema(default = "true")]
@@ -158,6 +164,8 @@ fn default_parameters() -> GenerateParameters {
         stop: Vec::new(),
         truncate: None,
         watermark: false,
+        use_grammar_constraint: false,
+        grammar: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,
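With the router exposing `use_grammar_constraint` and `grammar` on `GenerateParameters`, a request can pass them in the usual `parameters` object. A minimal sketch, assuming a server built from this branch is listening on localhost:8080; the grammar below is only an example, not one shipped with the commit:

```python
import requests

payload = {
    "inputs": "Is the sky blue? Answer:",
    "parameters": {
        "max_new_tokens": 4,
        "use_grammar_constraint": True,          # new parameter from this commit
        "grammar": 'root ::= " yes" | " no"',    # illustrative EBNF grammar
    },
}
resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
print(resp.json()["generated_text"])
```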

View File

@@ -163,6 +163,8 @@ impl Validation {
             truncate,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
             decoder_input_details,
             top_n_tokens,
             ..
@@ -279,6 +281,8 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,

View File

@@ -39,6 +39,7 @@ setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+git+https://github.com/epfl-dlab/transformers_cfg.git@4d764337d4d0c167a32560e601fc7a5fb9f33e85
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
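A quick sanity check (an assumption: run inside the server's Python environment after installing the pinned revision) that the dependency exposes the two classes imported later in `utils/tokens.py`:

```python
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# If both imports resolve, the pinned transformers_cfg commit is usable by the server.
print(IncrementalGrammarConstraint.__name__, GrammarConstrainedLogitsProcessor.__name__)
```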

View File

@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -258,7 +258,7 @@ class CausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
+    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
         max_input_length = 0

View File

@@ -233,7 +233,7 @@ class FlashCausalLMBatch(Batch):
         )
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -468,7 +468,7 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -589,6 +589,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            tokenizer=tokenizer,
         )

         speculative_ids = (

View File

@@ -190,7 +190,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)

View File

@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

View File

@@ -143,7 +143,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]

View File

@@ -15,6 +15,8 @@ from text_generation_server.utils.logits_process import (
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor


 class NextTokenChooser:
@@ -29,6 +31,9 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        tokenizer=None,
+        use_grammar_constraint=False,
+        grammar="",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -39,6 +44,12 @@ class NextTokenChooser:
             else None
         )
+        if use_grammar_constraint:
+            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+            self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+        else:
+            self.grammar_processor = None
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
@@ -60,6 +71,8 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(input_ids, scores)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -75,6 +88,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -86,6 +100,9 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            use_grammar_constraint=pb.use_grammar_constraint,
+            grammar=pb.grammar,
+            tokenizer=tokenizer,
         )
@@ -181,7 +198,10 @@ class HeterogeneousNextTokenChooser:
         self,
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
         watermark: List[bool],
+        use_grammar_constraint: List[bool],
+        grammar: List[str],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
@@ -212,6 +232,11 @@ class HeterogeneousNextTokenChooser:
             else None
         )
+        if use_grammar_constraint:
+            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+            warpers.append(grammar_processor)
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -359,6 +384,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -371,6 +397,9 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            tokenizer=tokenizer,
+            use_grammar_constraint=use_grammar_constraint,
+            grammar=grammar,
         )
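For reference, the constraint pattern that `NextTokenChooser` now wires up can be reproduced standalone with plain `transformers` generation. A minimal sketch, assuming `gpt2` as a stand-in model and an illustrative EBNF grammar; nothing here is part of the commit itself:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Same construction as in NextTokenChooser: grammar text + start rule + tokenizer.
grammar = IncrementalGrammarConstraint('root ::= "yes" | "no"', "root", tokenizer)
processor = GrammarConstrainedLogitsProcessor(grammar)

input_ids = tokenizer("Is water wet? Answer: ", return_tensors="pt").input_ids
output = model.generate(
    input_ids,
    max_new_tokens=3,
    logits_processor=[processor],  # masks logits so the output follows the grammar
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```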