diff --git a/proto/generate.proto b/proto/generate.proto
index 13cac7a3..3ef1d56d 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -51,6 +51,11 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}
 
+message LogitBias {
+    string word = 1;
+    float bias = 2;
+}
+
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
@@ -70,6 +75,9 @@ message NextTokenChooserParameters {
     bool watermark = 8;
     bool use_grammar_constraint = 9;
     string grammar = 10;
+    repeated LogitBias logit_bias = 11;
+    float guidance_scale = 12;
+    string negative_inputs = 13;
 }
 
 message StoppingCriteriaParameters {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 22cf6c11..7431251f 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -128,6 +128,10 @@ impl Client {
                     watermark: true,
                     use_grammar_constraint: false,
                     grammar: "".to_string(),
+                    logit_bias: Vec::new(),
+                    guidance_scale: 1.0,
+                    negative_inputs: "".to_string(),
+
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index c38b931b..3d11b917 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    Request, StoppingCriteriaParameters, Tokens,
+    Request, StoppingCriteriaParameters, Tokens, LogitBias,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/health.rs b/router/src/health.rs
index 78c94017..a6755202 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -46,6 +46,9 @@ impl Health {
                 watermark: false,
                 use_grammar_constraint: false,
                 grammar: "".to_string(),
+                logit_bias: Vec::new(),
+                guidance_scale: 1.0,
+                negative_inputs: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 8e19605c..f481465f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -122,12 +122,21 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "false", example = true)]
     pub watermark: bool,
     #[serde(default)]
-    #[schema(default = "true")]
+    #[schema(default = "false")]
    pub use_grammar_constraint: bool,
     #[serde(default)]
-    #[schema(default = "false")]
+    #[schema(default = "")]
     pub grammar: String,
     #[serde(default)]
+    #[schema(default = "null", nullable = true)]
+    pub logit_bias: Vec<(String, f32)>,
+    #[serde(default)]
+    #[schema(default = "1.0")]
+    pub guidance_scale: f32,
+    #[serde(default)]
+    #[schema(default = "")]
+    pub negative_inputs: String,
+    #[serde(default)]
     #[schema(default = "")]
     pub details: bool,
     #[serde(default)]
@@ -166,6 +175,9 @@ fn default_parameters() -> GenerateParameters {
         watermark: false,
         use_grammar_constraint: false,
         grammar: "".to_string(),
+        logit_bias: Vec::new(),
+        guidance_scale: 1.0,
+        negative_inputs: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 106cacc4..f359db2d 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -356,6 +356,11 @@ mod tests {
             seed: 0,
             repetition_penalty: 0.0,
             watermark: false,
+            use_grammar_constraint: false,
+            grammar: "".to_string(),
+            logit_bias: Vec::new(),
+            guidance_scale: 1.0,
+            negative_inputs: "".to_string(),
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index f5a28b94..43112afa 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -2,7 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters, LogitBias};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
@@ -165,6 +165,9 @@ impl Validation {
             watermark,
             use_grammar_constraint,
             grammar,
+            guidance_scale,
+            negative_inputs,
+            logit_bias,
             decoder_input_details,
             top_n_tokens,
             ..
@@ -268,10 +271,19 @@
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let (inputs, input_length, max_new_tokens) = self
+        let (inputs, _input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
+        let (negative_inputs, input_length, max_new_tokens) = self
+            .validate_input(negative_inputs, truncate, Some(max_new_tokens))
+            .await?;
+
+        let logit_biases: Vec<LogitBias> = logit_bias
+            .into_iter()
+            .map(|(word, bias)| LogitBias { word, bias })
+            .collect();
+
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -283,6 +295,9 @@
             watermark,
             use_grammar_constraint,
             grammar,
+            logit_bias: logit_biases,
+            guidance_scale,
+            negative_inputs,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index fed5e6f3..b28575fa 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -7,6 +7,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
 )
 
 from text_generation_server.models.custom_modeling.bloom_modeling import (
@@ -28,10 +29,11 @@ class BloomCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, model=model, dtype=dtype, device=device)
         batch.keys_head_dim_last = False
         return batch
 
@@ -101,7 +103,7 @@ class BLOOMSharded(CausalLM):
         return BloomCausalLMBatch
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
     ):
         outputs = self.model.forward(
             input_ids=input_ids,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 97b04b60..56284061 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -3,7 +3,7 @@ import time
 from dataclasses import dataclass
 from opentelemetry import trace
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
@@ -69,6 +69,7 @@ class CausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
@@ -87,9 +88,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
@@ -258,7 +259,7 @@ class CausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
+    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
         max_input_length = 0
@@ -545,7 +546,7 @@ class CausalLM(Model):
         )
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cdf3803b..126e76d9 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -8,7 +8,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
@@ -108,6 +108,7 @@ class FlashCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
@@ -181,7 +182,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.append(r.parameters)
 
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
@@ -233,7 +234,7 @@ class FlashCausalLMBatch(Batch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, model
         )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -468,7 +469,7 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -590,6 +591,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=tokenizer,
+            model=model,
         )
 
         speculative_ids = (
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 18b98c66..5967cbf6 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type, List
 
@@ -48,6 +48,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
@@ -124,7 +125,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             next_token_chooser_parameters.append(r.parameters)
 
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
@@ -190,7 +191,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, model
         )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a2c30255..4e48d299 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -8,6 +8,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
 )
 from text_generation_server.models import CausalLM
 from text_generation_server.models.causal_lm import CausalLMBatch
@@ -73,6 +74,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "GalacticaCausalLMBatch":
@@ -92,9 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
@@ -225,7 +227,7 @@ class GalacticaSharded(CausalLM):
         )
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
     ):
         outputs = self.model.forward(
             input_ids=input_ids,
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 1f633f8a..37fa397b 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -7,6 +7,7 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
     ProcessorMixin,
 )
 from typing import Optional, Tuple, List, Type, Dict
@@ -96,6 +97,7 @@ class IdeficsCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         processor: ProcessorMixin,  # Hack
         dtype: torch.dtype,
         device: torch.device,
@@ -114,9 +116,9 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
@@ -642,7 +644,7 @@ class IdeficsCausalLM(Model):
         pixel_values,
         image_hidden_states,
         image_attention_mask,
-        past_key_values: Optional = None,
+        past_key_values: Optional = None,  # type: ignore
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index e419467f..c858ccdd 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -4,7 +4,7 @@ import torch.distributed
 from pathlib import Path
 from typing import Optional, Type
 from opentelemetry import trace
-from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel
 from huggingface_hub import hf_hub_download
 import json
 
@@ -29,10 +29,11 @@ class MPTCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device, model=model)
         batch.keys_head_dim_last = False
         return batch
 
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index d506a46a..8c3c98c0 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -3,7 +3,7 @@ import time
 from dataclasses import dataclass
 from opentelemetry import trace
 
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -75,6 +75,7 @@ class Seq2SeqLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
@@ -96,9 +97,9 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
@@ -278,7 +279,7 @@ class Seq2SeqLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+    def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Seq2SeqLMBatch":
         """Concatenate multiple batches together by padding internal torch tensors"""
 
         # Used for padding
@@ -587,9 +588,9 @@ class Seq2SeqLM(Model):
         input_ids,
         attention_mask,
         decoder_input_ids,
-        decoder_attention_mask: Optional,
-        encoder_last_hidden_state: Optional,
-        past_key_values: Optional = None,
+        decoder_attention_mask: Optional,  # type: ignore
+        encoder_last_hidden_state: Optional,  # type: ignore
+        past_key_values: Optional = None,  # type: ignore
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index f85f27e5..8891d2b0 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import List, Optional
 
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
@@ -22,6 +22,7 @@ class Batch(ABC):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Batch":
@@ -33,7 +34,7 @@ class Batch(ABC):
 
     @classmethod
     @abstractmethod
-    def concatenate(cls, batches: List["Batch"]) -> "Batch":
+    def concatenate(cls, batches: List["Batch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Batch":
         raise NotImplementedError
 
     @abstractmethod
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 847ec583..878817a9 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -85,12 +85,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model,
                 self.model.dtype,
                 self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+                request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
             )
 
         max_supported_total_tokens = self.model.warmup(batch)
@@ -107,12 +108,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model,
                 self.model.dtype,
                 self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+                request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
             )
 
         generations, next_batch, timings = self.model.generate_token(batch)
@@ -143,7 +145,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer, model=self.model)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 998e1df3..68563cb5 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -14,7 +14,7 @@ from text_generation_server.utils.logits_process import (
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor, PreTrainedModel
 from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
 from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
 
@@ -32,8 +32,11 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        model=None,
         use_grammar_constraint=False,
         grammar="",
+        guidance_scale=1.0,
+        negative_inputs="",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -50,6 +53,19 @@ class NextTokenChooser:
         else:
             self.grammar_processor = None
 
+        if guidance_scale != 1.0 and model is not None and negative_inputs:
+            negative_inputs_t = tokenizer([negative_inputs], return_tensors="pt")
+            device = next(model.model.parameters()).device
+            self.guidance_scale_processor = UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                guidance_scale,
+                model.model,
+                unconditional_ids=negative_inputs_t["input_ids"].to(device),
+                unconditional_attention_mask=negative_inputs_t["attention_mask"].to(device),
+                use_cache=True,  # use cache permanently on for now.
+            )
+        else:
+            self.guidance_scale_processor = None
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
@@ -67,12 +83,15 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
+        if self.guidance_scale_processor is not None:
+            scores = self.guidance_scale_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(input_ids, scores)
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
-        if self.grammar_processor is not None:
-            scores = self.grammar_processor(input_ids, scores)
+
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -89,6 +108,7 @@ class NextTokenChooser:
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -103,6 +123,9 @@ class NextTokenChooser:
             use_grammar_constraint=pb.use_grammar_constraint,
             grammar=pb.grammar,
             tokenizer=tokenizer,
+            model=model,
+            guidance_scale=pb.guidance_scale,
+            negative_inputs=pb.negative_inputs,
         )
 
@@ -157,6 +180,7 @@ class StoppingCriteria:
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -199,9 +223,12 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         watermark: List[bool],
         use_grammar_constraint: List[bool],
         grammar: List[str],
+        guidance_scale: List[float],
+        negative_inputs: List[str],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
@@ -232,10 +259,14 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
-        if use_grammar_constraint:
-            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
-            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
-            warpers.append(grammar_processor)
+        if any(use_grammar_constraint):
+            grammar_processors = {
+                i: GrammarConstrainedLogitsProcessor(IncrementalGrammarConstraint(grammar[i], "root", tokenizer))
+                for i, use_gc in enumerate(use_grammar_constraint) if use_gc
+            }
+            self.grammar_processor = HeterogeneousProcessorWrapper(grammar_processors)
+        else:
+            self.grammar_processor = None
 
         if any([x != 1.0 for x in temperature]):
             do_sample = [
@@ -294,6 +325,8 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.watermark_processor(input_ids, _scores)
             if self.repetition_processor is not None:
                 _scores = self.repetition_processor(input_ids, _scores)
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(input_ids, _scores)
 
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
@@ -385,6 +418,7 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -398,8 +432,11 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            use_grammar_constraint=use_grammar_constraint,
-            grammar=grammar,
+            model=model,
+            use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
+            grammar=[pb_.grammar for pb_ in pb],
+            guidance_scale=[pb_.guidance_scale for pb_ in pb],
+            negative_inputs=[pb_.negative_inputs for pb_ in pb],
         )
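
For context, the sketch below (not part of the diff) shows one way a client could exercise the new fields through the router's `/generate` HTTP endpoint once this change is deployed. The URL, prompt, bias values, and guidance scale are illustrative assumptions; the field shapes follow the `GenerateParameters` struct added in `router/src/lib.rs`, where `logit_bias` is declared as `Vec<(String, f32)>` and therefore serializes as a list of `[word, bias]` pairs.

```python
# Minimal sketch, assuming a TGI router is listening on localhost:8080.
import requests

payload = {
    "inputs": "Write a short product description for a mechanical keyboard:",
    "parameters": {
        "max_new_tokens": 64,
        # Vec<(String, f32)> on the router side -> list of [word, bias] pairs.
        "logit_bias": [["cheap", -2.0], ["durable", 1.5]],
        # Values other than 1.0 turn on classifier-free guidance on the server,
        # which wraps the model in UnbatchedClassifierFreeGuidanceLogitsProcessor.
        "guidance_scale": 1.5,
        # Negative prompt used for the unconditional branch of the guidance pass.
        "negative_inputs": "low quality, spammy marketing copy",
    },
}

resp = requests.post("http://localhost:8080/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])
```

Note that the default `guidance_scale` of 1.0 keeps the guidance processor disabled (per the `guidance_scale != 1.0` check in `tokens.py`), so existing clients are unaffected.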