Adding CFG (context-free grammar) support to TGI

Łukasz Olszewski 2023-12-20 12:42:56 +01:00
parent d077150eb7
commit e749d0cc5a
15 changed files with 62 additions and 9 deletions
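For context, a minimal sketch of how the new request parameters could be exercised against a server built from this branch. The `/generate` endpoint and the `parameters` object are TGI's existing API; the URL, prompt, and grammar below are illustrative only, and the grammar uses the EBNF dialect understood by transformers-cfg with a start rule named `root`.

```python
# Sketch only: assumes a TGI instance built from this branch listening on localhost:8080.
import requests

# Grammar in the transformers-cfg EBNF dialect; decoding starts from the "root" rule.
GRAMMAR = 'root ::= "yes" | "no"'

payload = {
    "inputs": "Is Paris the capital of France? Answer: ",
    "parameters": {
        "max_new_tokens": 4,
        "use_grammar_constraint": True,  # new field added by this commit
        "grammar": GRAMMAR,              # new field added by this commit
    },
}
response = requests.post("http://localhost:8080/generate", json=payload)
print(response.json()["generated_text"])
```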

View File

@ -1,3 +1,7 @@
## Warning
This is a personal copy of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes that are not meant for any serious use.
<div align="center">
<a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">

View File

@ -68,6 +68,8 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// enable grammar-constrained generation (transformers-cfg)
bool use_grammar_constraint = 9;
/// EBNF grammar to enforce when the constraint is enabled
string grammar = 10;
}
message StoppingCriteriaParameters {

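The new `grammar` field carries the grammar as a plain string. Judging from the server side of this commit (the `"root"` start symbol passed to `IncrementalGrammarConstraint`), it is expected in the EBNF dialect used by transformers-cfg, with generation starting from a rule named `root`. A hypothetical value, here constraining output to a bracketed list of integers:

```python
# Illustrative grammar only; rule names other than "root" are free to choose.
GRAMMAR = r'''
root   ::= "[" number ("," number)* "]"
number ::= [0-9]+
'''
```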
View File

@ -126,6 +126,8 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
watermark: true,
use_grammar_constraint: false,
grammar: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,

View File

@ -44,6 +44,8 @@ impl Health {
seed: 0,
repetition_penalty: 1.0,
watermark: false,
use_grammar_constraint: false,
grammar: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@ -123,6 +123,12 @@ pub(crate) struct GenerateParameters {
pub watermark: bool,
#[serde(default)]
#[schema(default = "false")]
pub use_grammar_constraint: bool,
#[serde(default)]
#[schema(default = "")]
pub grammar: String,
#[serde(default)]
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
#[schema(default = "true")]
@ -158,6 +164,8 @@ fn default_parameters() -> GenerateParameters {
stop: Vec::new(),
truncate: None,
watermark: false,
use_grammar_constraint: false,
grammar: "".to_string(),
details: false,
decoder_input_details: false,
seed: None,

View File

@ -163,6 +163,8 @@ impl Validation {
truncate,
seed,
watermark,
use_grammar_constraint,
grammar,
decoder_input_details,
top_n_tokens,
..
@ -279,6 +281,8 @@ impl Validation {
do_sample,
seed,
watermark,
use_grammar_constraint,
grammar,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,

View File

@ -39,6 +39,7 @@ setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
git+https://github.com/epfl-dlab/transformers_cfg.git@4d764337d4d0c167a32560e601fc7a5fb9f33e85
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@ -258,7 +258,7 @@ class CausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0

View File

@ -233,7 +233,7 @@ class FlashCausalLMBatch(Batch):
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -468,7 +468,7 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
@ -589,6 +589,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
tokenizer=tokenizer,
)
speculative_ids = (

View File

@ -190,7 +190,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)

View File

@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)

View File

@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)

View File

@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)

View File

@ -143,7 +143,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1:
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate(batches)
batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
concat_ns = time.time_ns() - start_concat
else:
batch = batches[0]

View File

@ -15,6 +15,8 @@ from text_generation_server.utils.logits_process import (
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
class NextTokenChooser:
@ -29,6 +31,9 @@ class NextTokenChooser:
do_sample=False,
seed=0,
device="cpu",
tokenizer=None,
use_grammar_constraint=False,
grammar="",
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@ -39,6 +44,12 @@ class NextTokenChooser:
else None
)
if use_grammar_constraint:
    grammar_constraint = IncrementalGrammarConstraint(grammar, "root", tokenizer)
    self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint)
else:
    self.grammar_processor = None
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
@ -60,6 +71,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.grammar_processor is not None:
    scores = self.grammar_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@ -75,6 +88,7 @@ class NextTokenChooser:
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
) -> "NextTokenChooser":
return NextTokenChooser(
watermark=pb.watermark,
@ -86,6 +100,9 @@ class NextTokenChooser:
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
use_grammar_constraint=pb.use_grammar_constraint,
grammar=pb.grammar,
tokenizer=tokenizer,
)
@ -181,7 +198,10 @@ class HeterogeneousNextTokenChooser:
self,
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
watermark: List[bool],
use_grammar_constraint: List[bool],
grammar: List[str],
temperature: List[float],
repetition_penalty: List[float],
top_k: List[int],
@ -212,6 +232,11 @@ class HeterogeneousNextTokenChooser:
else None
)
if any(use_grammar_constraint):
    # use_grammar_constraint and grammar arrive as per-request lists; a single
    # processor is built, so use the grammar of the first constrained request
    first = use_grammar_constraint.index(True)
    grammar_constraint = IncrementalGrammarConstraint(grammar[first], "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint)
    warpers.append(grammar_processor)
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@ -359,6 +384,7 @@ class HeterogeneousNextTokenChooser:
pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
@ -371,6 +397,9 @@ class HeterogeneousNextTokenChooser:
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
tokenizer=tokenizer,
use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
grammar=[pb_.grammar for pb_ in pb],
)
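For reference, a minimal standalone sketch of the same transformers_cfg pieces that the choosers above now wire in, driving Hugging Face `generate` directly. The model name and grammar are illustrative and not part of this commit; TGI applies the processor to the raw scores itself rather than going through `generate`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Illustrative model; any causal LM whose tokenizer transformers-cfg supports should do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

grammar_str = 'root ::= "yes" | "no"'
constraint = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(constraint)

inputs = tokenizer("Is the sky blue? Answer: ", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=3,
    do_sample=False,  # greedy decoding keeps the example deterministic
    # The processor masks every token that would leave the grammar, which is
    # what NextTokenChooser.__call__ now does with the scores it receives.
    logits_processor=LogitsProcessorList([grammar_processor]),
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```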