diff --git a/README.md b/README.md
index d99b7306..22670987 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,7 @@
+## Warning
+
+This is a personal fork of [huggingface/text-generation-inference](https://github.com/huggingface/text-generation-inference). It contains experimental changes (grammar-constrained decoding via transformers_cfg) and is not intended for any serious use.
+
diff --git a/proto/generate.proto b/proto/generate.proto
index 1f30df38..13cac7a3 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -68,6 +68,10 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
+ /// constrain decoding to a formal grammar (via transformers_cfg)
+ bool use_grammar_constraint = 9;
+ /// the grammar definition, in the EBNF syntax consumed by transformers_cfg
+ string grammar = 10;
}
message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 4723d664..22cf6c11 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -126,6 +126,8 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
watermark: true,
+ use_grammar_constraint: false,
+ grammar: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
diff --git a/router/src/health.rs b/router/src/health.rs
index ab290fc1..78c94017 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -44,6 +44,8 @@ impl Health {
seed: 0,
repetition_penalty: 1.0,
watermark: false,
+ use_grammar_constraint: false,
+ grammar: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 898fcd04..8e19605c 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -123,6 +123,12 @@ pub(crate) struct GenerateParameters {
pub watermark: bool,
+ #[serde(default)]
+ #[schema(default = "false")]
+ pub use_grammar_constraint: bool,
+ #[serde(default)]
+ #[schema(default = "")]
+ pub grammar: String,
#[serde(default)]
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
#[schema(default = "true")]
@@ -158,6 +164,8 @@ fn default_parameters() -> GenerateParameters {
stop: Vec::new(),
truncate: None,
watermark: false,
+ use_grammar_constraint: false,
+ grammar: "".to_string(),
details: false,
decoder_input_details: false,
seed: None,
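
For reference, the two new `GenerateParameters` fields are meant to be passed in the `/generate` request body like any other parameter. Below is a minimal sketch of such a request against a locally running server, assuming TGI's usual HTTP API; the host/port and the toy EBNF grammar are illustrative and not part of this patch.

```python
# Sketch: exercising the new parameters against a locally running TGI instance.
# Assumes the server listens on localhost:8080; the grammar string is a toy example.
import requests

grammar = 'root ::= "yes" | "no"'  # EBNF as consumed by transformers_cfg (assumed format)

resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "Is the sky blue? Answer:",
        "parameters": {
            "max_new_tokens": 4,
            "use_grammar_constraint": True,
            "grammar": grammar,
        },
    },
)
print(resp.json()["generated_text"])
```
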
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 64f25c82..f5a28b94 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -163,6 +163,8 @@ impl Validation {
truncate,
seed,
watermark,
+ use_grammar_constraint,
+ grammar,
decoder_input_details,
top_n_tokens,
..
@@ -279,6 +281,8 @@ impl Validation {
do_sample,
seed,
watermark,
+ use_grammar_constraint,
+ grammar,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
diff --git a/server/requirements_common.txt b/server/requirements_common.txt
index 5a321834..ec17498e 100644
--- a/server/requirements_common.txt
+++ b/server/requirements_common.txt
@@ -39,6 +39,7 @@ setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+git+https://github.com/epfl-dlab/transformers_cfg.git@4d764337d4d0c167a32560e601fc7a5fb9f33e85
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 7b10256c..97b04b60 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -87,7 +87,7 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
- next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+ next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
@@ -258,7 +258,7 @@ class CausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
- def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
+ def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 930082cd..cdf3803b 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -233,7 +233,7 @@ class FlashCausalLMBatch(Batch):
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
- next_token_chooser_parameters, dtype, device
+ next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -468,7 +468,7 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
- def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+ def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
@@ -589,6 +589,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
+ tokenizer=tokenizer,
)
speculative_ids = (
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 8c6cb025..18b98c66 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -190,7 +190,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
- next_token_chooser_parameters, dtype, device
+ next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 42ff1c80..a2c30255 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
- next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+ next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 2f28688d..1f633f8a 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
- next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+ next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f2e4cec6..d506a46a 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
- next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+ next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index d5adbd32..847ec583 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -143,7 +143,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1:
start_concat = time.time_ns()
- batch = self.model.batch_type.concatenate(batches)
+ batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
concat_ns = time.time_ns() - start_concat
else:
batch = batches[0]
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 04cc8d97..998e1df3 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -15,6 +15,8 @@ from text_generation_server.utils.logits_process import (
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
+from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
class NextTokenChooser:
@@ -29,6 +31,9 @@ class NextTokenChooser:
do_sample=False,
seed=0,
device="cpu",
+ tokenizer=None,
+ use_grammar_constraint=False,
+ grammar="",
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@@ -39,6 +44,12 @@ class NextTokenChooser:
else None
)
+ if use_grammar_constraint:
+ grammar_constraint = IncrementalGrammarConstraint(grammar, "root", tokenizer)
+ self.grammar_processor = GrammarConstrainedLogitsProcessor(grammar_constraint)
+ else:
+ self.grammar_processor = None
+
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
@@ -60,6 +71,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
+ if self.grammar_processor is not None:
+ scores = self.grammar_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@@ -75,6 +88,7 @@ class NextTokenChooser:
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
) -> "NextTokenChooser":
return NextTokenChooser(
watermark=pb.watermark,
@@ -86,6 +100,9 @@ class NextTokenChooser:
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
+ use_grammar_constraint=pb.use_grammar_constraint,
+ grammar=pb.grammar,
+ tokenizer=tokenizer,
)
@@ -181,7 +198,10 @@ class HeterogeneousNextTokenChooser:
self,
dtype: torch.dtype,
device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
watermark: List[bool],
+ use_grammar_constraint: List[bool],
+ grammar: List[str],
temperature: List[float],
repetition_penalty: List[float],
top_k: List[int],
@@ -212,6 +232,11 @@ class HeterogeneousNextTokenChooser:
else None
)
+ if any(use_grammar_constraint):
+ # assumes all requests in the batch share the same grammar
+ grammar_constraint = IncrementalGrammarConstraint(grammar[0], "root", tokenizer)
+ warpers.append(GrammarConstrainedLogitsProcessor(grammar_constraint))
+
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -359,6 +384,7 @@ class HeterogeneousNextTokenChooser:
pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype,
device: torch.device,
+ tokenizer: PreTrainedTokenizerBase,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
@@ -371,6 +397,9 @@ class HeterogeneousNextTokenChooser:
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
+ tokenizer=tokenizer,
+ use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
+ grammar=[pb_.grammar for pb_ in pb],
)
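
For context on what the new dependency does, here is a standalone sketch of the transformers_cfg mechanism that the `NextTokenChooser` changes above wire into TGI: the same `IncrementalGrammarConstraint` / `GrammarConstrainedLogitsProcessor` pair, driven through vanilla `transformers` generation. The model name and the toy grammar are illustrative assumptions, not part of this patch.

```python
# Sketch of the underlying transformers_cfg mechanism this patch wires into TGI,
# shown standalone with vanilla transformers. Model name and grammar are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

grammar_str = 'root ::= "yes" | "no"'  # toy EBNF grammar
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

inputs = tokenizer("Is the sky blue? Answer:", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=4,
    logits_processor=[grammar_processor],  # masks logits so only grammar-legal tokens remain
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Inside TGI, the same processor is applied directly to the per-step scores in `NextTokenChooser.__call__` (see the `self.grammar_processor` branch above), which is equivalent to passing it via `logits_processor` here.
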