Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit e749d0cc5a (parent d077150eb7): Adding CFG (context-free grammar) to TGI

@@ -1,3 +1,7 @@
+## Warning
+
+This is a personal copy of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes that are not meant for any serious use.
+
 <div align="center">
 
 <a href="https://www.youtube.com/watch?v=jlMAX2Oaht0">

@@ -68,6 +68,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
+    bool use_grammar_constraint = 9;
+    string grammar = 10;
 }
 
 message StoppingCriteriaParameters {

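For orientation, a minimal sketch of a request built against the two new proto fields from the Python side. It assumes the gRPC stubs have been regenerated from this proto (the `generate_pb2` module path follows the server's existing layout); the grammar text itself is only illustrative:

```python
# Sketch: construct NextTokenChooserParameters with the two new fields.
# Assumes regenerated Python stubs for the proto above; the grammar text is illustrative.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    repetition_penalty=1.0,
    watermark=False,
    use_grammar_constraint=True,       # new field 9
    grammar='root ::= "yes" | "no"',   # new field 10 (GBNF-style text, start rule "root")
)
```
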
@@ -126,6 +126,8 @@ impl Client {
                 seed: 0,
                 repetition_penalty: 1.2,
                 watermark: true,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,

@@ -44,6 +44,8 @@ impl Health {
             seed: 0,
             repetition_penalty: 1.0,
             watermark: false,
+            use_grammar_constraint: false,
+            grammar: "".to_string(),
         }),
         stopping_parameters: Some(StoppingCriteriaParameters {
             max_new_tokens: 1,

@@ -123,6 +123,12 @@ pub(crate) struct GenerateParameters {
     pub watermark: bool,
     #[serde(default)]
     #[schema(default = "true")]
+    pub use_grammar_constraint: bool,
+    #[serde(default)]
+    #[schema(default = "false")]
+    pub grammar: String,
+    #[serde(default)]
+    #[schema(default = "")]
     pub details: bool,
     #[serde(default)]
     #[schema(default = "true")]

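Since `GenerateParameters` is what the router deserializes from its HTTP API, the two new fields become request parameters. A hedged sketch of calling the `/generate` route with them (host, port, prompt, and grammar are illustrative; the field names mirror the struct above):

```python
# Sketch: call a TGI instance built from this branch with a grammar constraint.
# The URL is illustrative; parameter names mirror GenerateParameters above.
import requests

payload = {
    "inputs": "Is the sky blue? Answer:",
    "parameters": {
        "max_new_tokens": 4,
        "use_grammar_constraint": True,
        "grammar": 'root ::= "yes" | "no"',
    },
}
resp = requests.post("http://localhost:8080/generate", json=payload)
print(resp.json()["generated_text"])
```
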
@@ -158,6 +164,8 @@ fn default_parameters() -> GenerateParameters {
         stop: Vec::new(),
         truncate: None,
         watermark: false,
+        use_grammar_constraint: false,
+        grammar: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,

@@ -163,6 +163,8 @@ impl Validation {
             truncate,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
             decoder_input_details,
             top_n_tokens,
             ..

@@ -279,6 +281,8 @@ impl Validation {
             do_sample,
             seed,
             watermark,
+            use_grammar_constraint,
+            grammar,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,

@@ -39,6 +39,7 @@ setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+git+https://github.com/epfl-dlab/transformers_cfg.git@4d764337d4d0c167a32560e601fc7a5fb9f33e85
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

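The pinned `transformers_cfg` revision is what actually enforces the grammar. For context, a standalone sketch of how that library is used with a plain `transformers` model, independent of TGI (model name, prompt, and grammar are illustrative; the two classes are the same ones imported in the server changes further down):

```python
# Standalone transformers_cfg sketch, independent of TGI.
# Model, prompt, and grammar are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

grammar = IncrementalGrammarConstraint('root ::= "yes" | "no"', "root", tokenizer)
processor = GrammarConstrainedLogitsProcessor(grammar)

inputs = tokenizer("Is water wet? Answer:", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=3, logits_processor=[processor])
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
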
@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@ -258,7 +258,7 @@ class CausalLMBatch(Batch):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@tracer.start_as_current_span("concatenate")
|
@tracer.start_as_current_span("concatenate")
|
||||||
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
|
def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
|
||||||
# Used for padding
|
# Used for padding
|
||||||
total_batch_size = 0
|
total_batch_size = 0
|
||||||
max_input_length = 0
|
max_input_length = 0
|
||||||
|
@@ -233,7 +233,7 @@ class FlashCausalLMBatch(Batch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -468,7 +468,7 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}

@@ -589,6 +589,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            tokenizer=tokenizer,
         )
 
         speculative_ids = (

@@ -190,7 +190,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )

@@ -143,7 +143,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]

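The tokenizer threading above is what ties the model-side changes together: when batches are merged, the per-request `NextTokenChooserParameters` are used to rebuild a single heterogeneous chooser, and building a grammar constraint requires the tokenizer. A minimal sketch of that rebuild, mirroring the `concatenate` change shown earlier (names follow the diff; this is illustrative, not the full method):

```python
# Sketch: inside concatenate(), the merged batch rebuilds its chooser from the
# collected per-request parameters; the tokenizer (forwarded here from
# self.model.tokenizer) is now required so grammar constraints can be rebuilt.
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
    next_token_chooser_parameters,
    dtype=batches[0].next_token_chooser.dtype,
    device=batches[0].next_token_chooser.device,
    tokenizer=tokenizer,
)
```
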
@@ -15,6 +15,8 @@ from text_generation_server.utils.logits_process import (
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
 
 
 class NextTokenChooser:

@@ -29,6 +31,9 @@ class NextTokenChooser:
         do_sample=False,
         seed=0,
         device="cpu",
+        tokenizer=None,
+        use_grammar_constraint=False,
+        grammar="",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None

@@ -39,6 +44,12 @@ class NextTokenChooser:
             else None
         )
 
+        if use_grammar_constraint:
+            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+            self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+        else:
+            self.grammar_processor = None
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)

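Putting the constructor changes together, a hedged sketch of building the extended chooser directly. It assumes the usual module layout of the server package and an illustrative GBNF-style grammar whose start rule is named `root`, matching the start-rule name passed to `IncrementalGrammarConstraint` above:

```python
# Sketch: direct construction of the extended NextTokenChooser.
# Module path assumes the usual server layout; tokenizer and grammar are illustrative.
from transformers import AutoTokenizer
from text_generation_server.utils.tokens import NextTokenChooser

tokenizer = AutoTokenizer.from_pretrained("gpt2")

chooser = NextTokenChooser(
    device="cpu",
    tokenizer=tokenizer,
    use_grammar_constraint=True,
    grammar='root ::= "yes" | "no"',   # start rule must be named "root" in this commit
)
# chooser(input_ids, scores) now runs the grammar processor over the scores
# before choosing the next token, as in the __call__ change below.
```
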
@@ -60,6 +71,8 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)

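Conceptually, the grammar processor behaves like the other processors in this chain: it takes the `(batch, vocab)` score tensor and pushes disallowed tokens to negative infinity, so only grammar-legal continuations can be chosen. A simplified illustration of that masking idea (not the transformers_cfg implementation, which derives the allowed set incrementally from the grammar's parse state):

```python
import torch

def mask_disallowed_tokens(scores: torch.Tensor, allowed_token_ids: list) -> torch.Tensor:
    # Illustration of grammar-constrained decoding: every token id not allowed by
    # the grammar's current parse state gets -inf, so the log_softmax/argmax that
    # follows can only select continuations that stay inside the grammar.
    mask = torch.full_like(scores, float("-inf"))
    mask[:, allowed_token_ids] = 0.0
    return scores + mask
```
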
@@ -75,6 +88,7 @@ class NextTokenChooser:
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,

@@ -86,6 +100,9 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            use_grammar_constraint=pb.use_grammar_constraint,
+            grammar=pb.grammar,
+            tokenizer=tokenizer,
         )
 
 
@@ -181,7 +198,10 @@ class HeterogeneousNextTokenChooser:
         self,
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
         watermark: List[bool],
+        use_grammar_constraint: List[bool],
+        grammar: List[str],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],

@@ -212,6 +232,11 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
+        if use_grammar_constraint:
+            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+            warpers.append(grammar_processor)
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)

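In the batched path the grammar processor is appended to the same chain as the temperature/top-k/top-p warpers, so it is invoked once per decoding step on the whole batch's scores. A minimal sketch of how such a chain is applied (illustrative; the function name is not from the diff):

```python
# Sketch: each element of the chain maps a (batch, vocab) score tensor to a new
# one; appending the grammar processor above means it runs at every step.
def apply_chain(chain, input_ids, scores):
    for processor in chain:
        scores = processor(input_ids, scores)
    return scores
```
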
@@ -359,6 +384,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],

@@ -371,6 +397,9 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            tokenizer=tokenizer,
+            use_grammar_constraint=use_grammar_constraint,
+            grammar=grammar,
         )
 