adding guidance and extra parameters for token bias

Łukasz Olszewski 2023-12-22 11:14:06 +01:00
parent 1108560745
commit d84b38e30d
18 changed files with 147 additions and 48 deletions

View File

@ -51,6 +51,11 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
message LogitBias {
string word = 1;
float bias = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
@ -70,6 +75,9 @@ message NextTokenChooserParameters {
bool watermark = 8;
bool use_grammar_constraint = 9;
string grammar = 10;
repeated LogitBias logit_bias = 11;
float guidance_scale = 12;
string negative_inputs = 13;
}
message StoppingCriteriaParameters {

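The new protobuf fields carry three request-level knobs: repeated `LogitBias` entries pairing a word with an additive bias, a classifier-free `guidance_scale` (1.0 leaves scores untouched), and a `negative_inputs` string used as the unconditional prompt. The Python sketch below only illustrates the intended effect of a single `LogitBias` entry on next-token logits; the helper name and the exact server-side application are assumptions, not something this commit shows.

    import torch
    from transformers import AutoTokenizer

    def apply_logit_bias(logits: torch.Tensor, tokenizer, word: str, bias: float) -> torch.Tensor:
        # Bias every token id that `word` tokenizes to; positive values make the
        # word more likely, negative values suppress it. Illustrative only.
        token_ids = tokenizer(word, add_special_tokens=False)["input_ids"]
        logits[..., token_ids] += bias
        return logits

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    logits = torch.zeros(1, tokenizer.vocab_size)
    logits = apply_logit_bias(logits, tokenizer, "garlic", 2.0)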
View File

@ -128,6 +128,10 @@ impl Client {
watermark: true,
use_grammar_constraint: false,
grammar: "".to_string(),
logit_bias: Vec::new(),
guidance_scale: 1.0,
negative_inputs: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,

View File

@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
Request, StoppingCriteriaParameters, Tokens,
Request, StoppingCriteriaParameters, Tokens, LogitBias,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

View File

@ -46,6 +46,9 @@ impl Health {
watermark: false,
use_grammar_constraint: false,
grammar: "".to_string(),
logit_bias: Vec::new(),
guidance_scale: 1.0,
negative_inputs: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,

View File

@ -122,12 +122,21 @@ pub(crate) struct GenerateParameters {
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
#[schema(default = "true")]
#[schema(default = "false")]
pub use_grammar_constraint: bool,
#[serde(default)]
#[schema(default = "false")]
#[schema(default = "")]
pub grammar: String,
#[serde(default)]
#[schema(default = "null", nullable = true)]
pub logit_bias: Vec<(String, f32)>,
#[serde(default)]
#[schema(default = "1.0")]
pub guidance_scale: f32,
#[serde(default)]
#[schema(default = "")]
pub negative_inputs: String,
#[serde(default)]
#[schema(default = "")]
pub details: bool,
#[serde(default)]
@ -166,6 +175,9 @@ fn default_parameters() -> GenerateParameters {
watermark: false,
use_grammar_constraint: false,
grammar: "".to_string(),
logit_bias: Vec::new(),
guidance_scale: 1.0,
negative_inputs: "".to_string(),
details: false,
decoder_input_details: false,
seed: None,

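On the router side these become optional fields of the public GenerateParameters, so a client can pass them in the JSON body of /generate. A hedged example request is sketched below; the host/port and the exact values a given deployment accepts are assumptions.

    import requests

    payload = {
        "inputs": "A recipe for a quick pasta dish:",
        "parameters": {
            "max_new_tokens": 64,
            # (word, bias) tuples serialize as two-element JSON arrays
            "logit_bias": [["garlic", 2.0], ["pineapple", -5.0]],
            # 1.0 disables classifier-free guidance
            "guidance_scale": 1.5,
            "negative_inputs": "A recipe that uses a microwave:",
        },
    }
    response = requests.post("http://localhost:8080/generate", json=payload)
    print(response.json())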
View File

@ -356,6 +356,11 @@ mod tests {
seed: 0,
repetition_penalty: 0.0,
watermark: false,
use_grammar_constraint: false,
grammar: "".to_string(),
logit_bias: Vec::new(),
guidance_scale: 1.0,
negative_inputs: "".to_string(),
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,

View File

@ -2,7 +2,7 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest};
use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters, LogitBias};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
@ -165,6 +165,9 @@ impl Validation {
watermark,
use_grammar_constraint,
grammar,
guidance_scale,
negative_inputs,
logit_bias,
decoder_input_details,
top_n_tokens,
..
@ -268,10 +271,19 @@ impl Validation {
.unwrap_or(Ok(None))?;
// Validate inputs
let (inputs, input_length, max_new_tokens) = self
let (inputs, _input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;
let (negative_inputs, input_length, max_new_tokens) = self
.validate_input(negative_inputs, truncate, Some(max_new_tokens))
.await?;
let logit_biases: Vec<LogitBias> = logit_bias
.into_iter()
.map(|(word, bias)| LogitBias { word, bias })
.collect();
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
@ -283,6 +295,9 @@ impl Validation {
watermark,
use_grammar_constraint,
grammar,
logit_bias: logit_biases,
guidance_scale,
negative_inputs,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,

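Validation truncates negative_inputs with the same budget as the main prompt and converts the (word, bias) tuples into the repeated LogitBias message before building NextTokenChooserParameters. A rough Python equivalent of that mapping, assuming the protobuf stubs have been regenerated with the new fields, might look like:

    from text_generation_server.pb import generate_pb2

    logit_bias = [("garlic", 2.0), ("pineapple", -5.0)]
    parameters = generate_pb2.NextTokenChooserParameters(
        temperature=0.7,
        # repeated LogitBias built from the request's (word, bias) tuples
        logit_bias=[generate_pb2.LogitBias(word=w, bias=b) for w, b in logit_bias],
        guidance_scale=1.5,
        negative_inputs="A recipe that uses a microwave:",
    )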
View File

@ -7,6 +7,7 @@ from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase,
PreTrainedModel,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
@ -28,10 +29,11 @@ class BloomCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, model=model, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
@ -101,7 +103,7 @@ class BLOOMSharded(CausalLM):
return BloomCausalLMBatch
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
):
outputs = self.model.forward(
input_ids=input_ids,

View File

@ -3,7 +3,7 @@ import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
@ -69,6 +69,7 @@ class CausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
@ -87,9 +88,9 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
r.stopping_parameters, tokenizer, model
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
@ -258,7 +259,7 @@ class CausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "CausalLMBatch":
# Used for padding
total_batch_size = 0
max_input_length = 0
@ -545,7 +546,7 @@ class CausalLM(Model):
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {

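Every batch builder now receives the model as well, so NextTokenChooser (and, by extension, the guidance processor) can call the underlying network. A hedged sketch of the updated call shape, where `model` stands for the TGI Model wrapper owning the tokenizer, dtype and device, and `pb_batch` and `batches` are assumed to exist:

    from text_generation_server.models.causal_lm import CausalLMBatch

    batch = CausalLMBatch.from_pb(
        pb=pb_batch,
        tokenizer=model.tokenizer,
        model=model,
        dtype=model.dtype,
        device=model.device,
    )
    merged = CausalLMBatch.concatenate(
        batches, tokenizer=model.tokenizer, model=model
    )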
View File

@ -8,7 +8,7 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
@ -108,6 +108,7 @@ class FlashCausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
@ -181,7 +182,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
r.stopping_parameters, tokenizer, model
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
@ -233,7 +234,7 @@ class FlashCausalLMBatch(Batch):
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
next_token_chooser_parameters, dtype, device, tokenizer, model
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -468,7 +469,7 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
@ -590,6 +591,7 @@ class FlashCausalLMBatch(Batch):
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
tokenizer=tokenizer,
model=model,
)
speculative_ids = (

View File

@ -6,7 +6,7 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type, List
@ -48,6 +48,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
@ -124,7 +125,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
r.stopping_parameters, tokenizer, model
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
@ -190,7 +191,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
next_token_chooser_parameters, dtype, device, tokenizer, model
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)

View File

@ -8,6 +8,7 @@ from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase,
PreTrainedModel,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
@ -73,6 +74,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "GalacticaCausalLMBatch":
@ -92,9 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
r.stopping_parameters, tokenizer, model
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
@ -225,7 +227,7 @@ class GalacticaSharded(CausalLM):
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
):
outputs = self.model.forward(
input_ids=input_ids,

View File

@ -7,6 +7,7 @@ from transformers import (
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
PreTrainedModel,
ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict
@ -96,6 +97,7 @@ class IdeficsCausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
processor: ProcessorMixin, # Hack
dtype: torch.dtype,
device: torch.device,
@ -114,9 +116,9 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
r.stopping_parameters, tokenizer, model
)
stopping_criterias.append(stopping_criteria)
max_truncation = max(max_truncation, r.truncate)
@ -642,7 +644,7 @@ class IdeficsCausalLM(Model):
pixel_values,
image_hidden_states,
image_attention_mask,
past_key_values: Optional = None,
past_key_values: Optional = None,  # type: ignore
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
kwargs = {

View File

@ -4,7 +4,7 @@ import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel
from huggingface_hub import hf_hub_download
import json
@ -29,10 +29,11 @@ class MPTCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, model=model, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch

View File

@ -3,7 +3,7 @@ import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.tokens import batch_top_tokens
@ -75,6 +75,7 @@ class Seq2SeqLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "Seq2SeqLMBatch":
@ -96,9 +97,9 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs)
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
r.stopping_parameters, tokenizer, model
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
@ -278,7 +279,7 @@ class Seq2SeqLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding
@ -587,9 +588,9 @@ class Seq2SeqLM(Model):
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask: Optional,
encoder_last_hidden_state: Optional,
past_key_values: Optional = None,
decoder_attention_mask: Optional,  # type: ignore
encoder_last_hidden_state: Optional,  # type: ignore
past_key_values: Optional = None,  # type: ignore
) -> Tuple[
torch.Tensor,
torch.Tensor,

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
@ -22,6 +22,7 @@ class Batch(ABC):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
dtype: torch.dtype,
device: torch.device,
) -> "Batch":
@ -33,7 +34,7 @@ class Batch(ABC):
@classmethod
@abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
def concatenate(cls, batches: List["Batch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Batch":
raise NotImplementedError
@abstractmethod

View File

@ -85,12 +85,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch,
self.model.tokenizer,
self.model.processor,
self.model,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
)
max_supported_total_tokens = self.model.warmup(batch)
@ -107,12 +108,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch,
self.model.tokenizer,
self.model.processor,
self.model,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
)
generations, next_batch, timings = self.model.generate_token(batch)
@ -143,7 +145,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1:
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer, model=self.model)
concat_ns = time.time_ns() - start_concat
else:
batch = batches[0]

View File

@ -14,7 +14,7 @@ from text_generation_server.utils.logits_process import (
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor, PreTrainedModel
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
@ -32,8 +32,11 @@ class NextTokenChooser:
seed=0,
device="cpu",
tokenizer=None,
model=None,
use_grammar_constraint=False,
grammar="",
guidance_scale=1.0,
negative_inputs="",
):
self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None
@ -50,6 +53,19 @@ class NextTokenChooser:
else:
self.grammar_processor = None
if guidance_scale != 1.0 and model is not None and negative_inputs:
negative_inputs_t = tokenizer([negative_inputs], return_tensors="pt")
device = next(model.model.parameters()).device
self.guidance_scale_processor = UnbatchedClassifierFreeGuidanceLogitsProcessor(
guidance_scale,
model.model,
unconditional_ids=negative_inputs_t["input_ids"].to(device),
unconditional_attention_mask=negative_inputs_t["attention_mask"].to(device),
use_cache=True,  # the KV cache is always enabled for now.
)
else:
self.guidance_scale_processor = None
has_warpers = (
(temperature is not None and temperature != 1.0)
or (top_k is not None and top_k != 0)
@ -67,12 +83,15 @@ class NextTokenChooser:
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
if self.guidance_scale_processor is not None:
scores = self.guidance_scale_processor(input_ids, scores)
if self.grammar_processor is not None:
scores = self.grammar_processor(input_ids, scores)
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.grammar_processor is not None:
scores = self.grammar_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
@ -89,6 +108,7 @@ class NextTokenChooser:
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
) -> "NextTokenChooser":
return NextTokenChooser(
watermark=pb.watermark,
@ -103,6 +123,9 @@ class NextTokenChooser:
use_grammar_constraint=pb.use_grammar_constraint,
grammar=pb.grammar,
tokenizer=tokenizer,
model=model,
guidance_scale=pb.guidance_scale,
negative_inputs=pb.negative_inputs,
)
@ -157,6 +180,7 @@ class StoppingCriteria:
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@ -199,9 +223,12 @@ class HeterogeneousNextTokenChooser:
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
watermark: List[bool],
use_grammar_constraint: List[bool],
grammar: List[str],
guidance_scale: List[float],
negative_inputs: List[str],
temperature: List[float],
repetition_penalty: List[float],
top_k: List[int],
@ -232,10 +259,14 @@ class HeterogeneousNextTokenChooser:
else None
)
if use_grammar_constraint:
grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
warpers.append(grammar_processor)
if any(use_grammar_constraint):
grammar_processors = {
i: GrammarConstrainedLogitsProcessor(IncrementalGrammarConstraint(grammar[i], "root", tokenizer))
for i, use_gc in enumerate(use_grammar_constraint) if use_gc
}
self.grammar_processor = HeterogeneousProcessorWrapper(grammar_processors)
else:
self.grammar_processor = None
if any([x != 1.0 for x in temperature]):
do_sample = [
@ -294,6 +325,8 @@ class HeterogeneousNextTokenChooser:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
@ -385,6 +418,7 @@ class HeterogeneousNextTokenChooser:
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
model: PreTrainedModel,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
@ -398,8 +432,11 @@ class HeterogeneousNextTokenChooser:
device=device,
dtype=dtype,
tokenizer=tokenizer,
use_grammar_constraint=use_grammar_constraint,
grammar=grammar,
model=model,
use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
grammar=[pb_.grammar for pb_ in pb],
guidance_scale=[pb_.guidance_scale for pb_ in pb],
negative_inputs=[pb_.negative_inputs for pb_ in pb],
)