Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Adding CFG (context-free grammar) to TGI
parent d077150eb7
commit e749d0cc5a
README.md
@@ -1,3 +1,7 @@
+## Warning
+
+This is a personal copy of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes that are not meant for any serious use.
+
 <div align="center">
 
 <a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">
proto/generate.proto
@@ -68,6 +68,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    bool use_grammar_constraint = 9;
+    string grammar = 10;
 }
 
 message StoppingCriteriaParameters {
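The two new fields ride alongside the existing per-request sampling knobs, so anything that builds a `NextTokenChooserParameters` message can opt into grammar-constrained decoding. A minimal sketch of populating them from Python, assuming the regenerated protobuf stubs are importable as `text_generation_server.pb.generate_pb2` (the module path the TGI server uses); the grammar string is an illustrative EBNF, not something shipped with this commit:

```python
# Hedged sketch: assumes regenerated protobuf stubs; the grammar string is an
# illustrative assumption, not part of this commit.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    seed=0,
    repetition_penalty=1.0,
    watermark=False,
    use_grammar_constraint=True,        # new field 9
    grammar='root ::= "yes" | "no"',    # new field 10
)
```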
router/client/src/client.rs
@@ -126,6 +126,8 @@ impl Client {
                 seed: 0,
                 repetition_penalty: 1.2,
                 watermark: true,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
router/src/health.rs
@@ -44,6 +44,8 @@ impl Health {
                 seed: 0,
                 repetition_penalty: 1.0,
                 watermark: false,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
router/src/lib.rs
@@ -123,6 +123,12 @@ pub(crate) struct GenerateParameters {
     pub watermark: bool,
+    #[serde(default)]
+    #[schema(default = "true")]
+    pub use_grammar_constraint: bool,
+    #[serde(default)]
+    #[schema(default = "false")]
+    pub grammar: String,
     #[serde(default)]
     #[schema(default = "")]
     pub details: bool,
     #[serde(default)]
     #[schema(default = "true")]
@@ -158,6 +164,8 @@ fn default_parameters() -> GenerateParameters {
         stop: Vec::new(),
         truncate: None,
         watermark: false,
+        use_grammar_constraint: false,
+        grammar: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,
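With the router fields in place, the feature is reachable over the standard `/generate` endpoint, and both defaults are off (`false` and the empty string), so existing clients are unaffected. A hedged sketch of a request against a locally running instance; the envelope is standard TGI, while the prompt, grammar string, and port are illustrative assumptions:

```python
# Hedged sketch: standard TGI /generate envelope plus the two fields this
# commit adds; prompt, grammar, and port are illustrative assumptions.
import requests

payload = {
    "inputs": "Is the sky blue? Answer:",
    "parameters": {
        "max_new_tokens": 4,
        "use_grammar_constraint": True,
        "grammar": 'root ::= " yes" | " no"',
    },
}
resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
print(resp.json()["generated_text"])
```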
router/src/validation.rs
@@ -163,6 +163,8 @@ impl Validation {
             truncate,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
             decoder_input_details,
             top_n_tokens,
             ..
@@ -279,6 +281,8 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
server/requirements.txt
@@ -39,6 +39,7 @@ setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+git+https://github.com/epfl-dlab/transformers_cfg.git@4d764337d4d0c167a32560e601fc7a5fb9f33e85
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
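The heavy lifting is delegated to [transformers_cfg](https://github.com/epfl-dlab/transformers_cfg), pinned to a specific commit. Outside of TGI, the library is used roughly as below, mirroring its documented interface (details may differ slightly at the pinned commit); these are exactly the two objects this commit wires into TGI's token choosers:

```python
# Hedged sketch of transformers_cfg's documented standalone usage; the model,
# prompt, and grammar are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Compile the grammar against the model's vocabulary, then wrap it in a
# logits processor that masks tokens the grammar cannot accept next.
grammar = IncrementalGrammarConstraint('root ::= "yes" | "no"', "root", tokenizer)
processor = GrammarConstrainedLogitsProcessor(grammar)

input_ids = tokenizer("Is 7 a prime number? Answer: ", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=3, logits_processor=[processor])
print(tokenizer.decode(output[0], skip_special_tokens=True))
```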
server/text_generation_server/models/causal_lm.py
@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
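The tokenizer now rides along with every `NextTokenChooser.from_pb` call because, as the tokens.py hunks at the end of this diff show, building an `IncrementalGrammarConstraint` needs the model's vocabulary to map grammar terminals onto concrete token ids. The same one-line change repeats below for the Galactica, Idefics, and Seq2Seq batches, and `concatenate` gains an optional tokenizer argument so merged batches can rebuild their choosers.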
@@ -258,7 +258,7 @@ class CausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
+    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
         max_input_length = 0
server/text_generation_server/models/flash_causal_lm.py
@@ -233,7 +233,7 @@ class FlashCausalLMBatch(Batch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -468,7 +468,7 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -589,6 +589,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            tokenizer=tokenizer,
         )
 
         speculative_ids = (
server/text_generation_server/models/flash_mistral.py
@@ -190,7 +190,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
server/text_generation_server/models/galactica.py
@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
server/text_generation_server/models/idefics_causal_lm.py
@@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
server/text_generation_server/models/seq2seq_lm.py
@@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
server/text_generation_server/server.py
@@ -143,7 +143,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]
server/text_generation_server/utils/tokens.py
@@ -15,6 +15,8 @@ from text_generation_server.utils.logits_process import (
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
 
 
 class NextTokenChooser:
@@ -29,6 +31,9 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        tokenizer=None,
+        use_grammar_constraint=False,
+        grammar="",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -39,6 +44,12 @@ class NextTokenChooser:
             else None
         )
 
+        if use_grammar_constraint:
+            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+            self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+        else:
+            self.grammar_processor = None
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
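`IncrementalGrammarConstraint` compiles an EBNF grammar against the tokenizer's vocabulary, and `"root"` is hard-coded here as the start rule, so any grammar sent through the API must define a `root` rule. Purely as an illustration (no grammar ships with this commit), two strings in the GBNF-flavoured EBNF the library accepts:

```python
# Illustrative grammars only; this commit does not ship any. Both define the
# "root" start rule that the hard-coded second argument above expects.
yes_no_grammar = 'root ::= "yes" | "no"'

# Constrain output to a tiny JSON object such as {"answer": "yes"}.
json_grammar = r"""
root  ::= "{" ws "\"answer\"" ws ":" ws value ws "}"
value ::= "\"yes\"" | "\"no\""
ws    ::= [ \t\n]*
"""
```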
@@ -60,6 +71,8 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
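The grammar processor slots into the same pipeline as the watermark and repetition processors: it takes the raw logits and returns adjusted ones. Conceptually (a simplification, not the library's actual implementation), it advances an incremental parser with the tokens generated so far, asks which token ids could extend a valid derivation, and masks out everything else:

```python
import torch

def mask_to_grammar(scores: torch.Tensor, allowed_ids: torch.Tensor) -> torch.Tensor:
    """Simplified view of grammar-constrained decoding: token ids the parser
    cannot accept next are driven to -inf, so softmax gives them zero mass."""
    mask = torch.full_like(scores, float("-inf"))
    mask[..., allowed_ids] = 0.0
    return scores + mask
```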
@@ -75,6 +88,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -86,6 +100,9 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            use_grammar_constraint=pb.use_grammar_constraint,
+            grammar=pb.grammar,
+            tokenizer=tokenizer,
         )
 
 
@@ -181,7 +198,10 @@ class HeterogeneousNextTokenChooser:
         self,
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
         watermark: List[bool],
+        use_grammar_constraint: List[bool],
+        grammar: List[str],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
@@ -212,6 +232,11 @@ class HeterogeneousNextTokenChooser:
             else None
        )
 
+        if use_grammar_constraint:
+            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+            warpers.append(grammar_processor)
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
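Note that in this heterogeneous (batched) path, `use_grammar_constraint` and `grammar` arrive as per-request lists, yet a single constraint is built from them: `if use_grammar_constraint:` tests the truthiness of the whole list, and the `List[str]` is handed to `IncrementalGrammarConstraint` where a single grammar string is expected. The `from_pb` hunk below shows the same looseness, passing the bare names rather than list comprehensions like its neighbouring fields, which is consistent with the experimental warning at the top of this diff. A hypothetical per-request variant, shown only to make the presumably intended semantics concrete (not what this commit does):

```python
# Hypothetical sketch, not this commit's code: one processor per batch entry
# so each request can carry its own grammar (or none). The names
# use_grammar_constraint, grammar, and tokenizer are the constructor arguments
# from the surrounding diff.
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

grammar_processors = [
    GrammarConstrainedLogitsProcessor(
        IncrementalGrammarConstraint(g, "root", tokenizer)
    )
    if use and g
    else None
    for use, g in zip(use_grammar_constraint, grammar)
]
```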
@@ -359,6 +384,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -371,6 +397,9 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            tokenizer=tokenizer,
+            use_grammar_constraint=use_grammar_constraint,
+            grammar=grammar,
         )
 
 