Add guidance and extra parameters for token bias

Łukasz Olszewski 2023-12-22 11:14:06 +01:00
parent 1108560745
commit d84b38e30d
18 changed files with 147 additions and 48 deletions

View File

@@ -51,6 +51,11 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}

+message LogitBias {
+    string word = 1;
+    float bias = 2;
+}
+
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
@@ -70,6 +75,9 @@ message NextTokenChooserParameters {
     bool watermark = 8;
     bool use_grammar_constraint = 9;
     string grammar = 10;
+    repeated LogitBias logit_bias = 11;
+    float guidance_scale = 12;
+    string negative_inputs = 13;
 }

 message StoppingCriteriaParameters {
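Note: the new proto fields behave like any other proto3 fields and may be left unset. A minimal sketch of populating them from the generated Python module (the field values here are purely illustrative):

from text_generation_server.pb import generate_pb2

# NextTokenChooserParameters carrying the fields added above; anything not
# passed keeps its proto3 default.
params = generate_pb2.NextTokenChooserParameters(
    temperature=0.7,
    logit_bias=[
        generate_pb2.LogitBias(word="sea", bias=2.0),
        generate_pb2.LogitBias(word="ocean", bias=-4.0),
    ],
    guidance_scale=1.5,
    negative_inputs="low quality, repetitive text",
)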

View File

@@ -128,6 +128,10 @@ impl Client {
                 watermark: true,
                 use_grammar_constraint: false,
                 grammar: "".to_string(),
+                logit_bias: Vec::new(),
+                guidance_scale: 1.0,
+                negative_inputs: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,

View File

@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    Request, StoppingCriteriaParameters, Tokens,
+    Request, StoppingCriteriaParameters, Tokens, LogitBias,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

View File

@@ -46,6 +46,9 @@ impl Health {
                 watermark: false,
                 use_grammar_constraint: false,
                 grammar: "".to_string(),
+                logit_bias: Vec::new(),
+                guidance_scale: 1.0,
+                negative_inputs: "".to_string(),
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,

View File

@@ -122,12 +122,21 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "false", example = true)]
     pub watermark: bool,
     #[serde(default)]
-    #[schema(default = "true")]
+    #[schema(default = "false")]
     pub use_grammar_constraint: bool,
     #[serde(default)]
-    #[schema(default = "false")]
+    #[schema(default = "")]
     pub grammar: String,
     #[serde(default)]
+    #[schema(default = "null", nullable = true)]
+    pub logit_bias: Vec<(String, f32)>,
+    #[serde(default)]
+    #[schema(default = "1.0")]
+    pub guidance_scale: f32,
+    #[serde(default)]
+    #[schema(default = "")]
+    pub negative_inputs: String,
+    #[serde(default)]
     #[schema(default = "")]
     pub details: bool,
     #[serde(default)]
@@ -166,6 +175,9 @@ fn default_parameters() -> GenerateParameters {
         watermark: false,
         use_grammar_constraint: false,
         grammar: "".to_string(),
+        logit_bias: Vec::new(),
+        guidance_scale: 1.0,
+        negative_inputs: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,
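Note: on the HTTP side these fields map directly onto the request parameters. A sketch of a request body (the /generate route and port are assumptions about the deployment; logit_bias serializes as the Vec<(String, f32)> above, i.e. nested [word, bias] arrays):

import requests

response = requests.post(
    "http://localhost:3000/generate",  # assumed local router address
    json={
        "inputs": "A short poem about the sea:",
        "parameters": {
            "max_new_tokens": 64,
            "guidance_scale": 1.5,  # values other than 1.0 enable classifier-free guidance
            "negative_inputs": "rhyming couplets",
            "logit_bias": [["sea", 2.0], ["ocean", -4.0]],
        },
    },
)
print(response.json())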

View File

@@ -356,6 +356,11 @@ mod tests {
                 seed: 0,
                 repetition_penalty: 0.0,
                 watermark: false,
+                use_grammar_constraint: false,
+                grammar: "".to_string(),
+                logit_bias: Vec::new(),
+                guidance_scale: 1.0,
+                negative_inputs: "".to_string(),
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,

View File

@@ -2,7 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{LogitBias, NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
@@ -165,6 +165,9 @@ impl Validation {
             watermark,
             use_grammar_constraint,
             grammar,
+            guidance_scale,
+            negative_inputs,
+            logit_bias,
             decoder_input_details,
             top_n_tokens,
             ..
@@ -268,10 +271,19 @@ impl Validation {
             .unwrap_or(Ok(None))?;

         // Validate inputs
-        let (inputs, input_length, max_new_tokens) = self
+        let (inputs, _input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
+        let (negative_inputs, input_length, max_new_tokens) = self
+            .validate_input(negative_inputs, truncate, Some(max_new_tokens))
+            .await?;
+
+        let logit_biases: Vec<LogitBias> = logit_bias
+            .into_iter()
+            .map(|(word, bias)| LogitBias { word, bias })
+            .collect();

         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
@@ -283,6 +295,9 @@ impl Validation {
             watermark,
             use_grammar_constraint,
             grammar,
+            logit_bias: logit_biases,
+            guidance_scale,
+            negative_inputs,
         };

         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,

View File

@@ -7,6 +7,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
 )

 from text_generation_server.models.custom_modeling.bloom_modeling import (
@@ -28,10 +29,11 @@ class BloomCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, model=model, dtype=dtype, device=device)
         batch.keys_head_dim_last = False
         return batch
@@ -101,7 +103,7 @@ class BLOOMSharded(CausalLM):
         return BloomCausalLMBatch

     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
     ):
         outputs = self.model.forward(
             input_ids=input_ids,

View File

@@ -3,7 +3,7 @@ import time

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
@@ -69,6 +69,7 @@ class CausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
@@ -87,9 +88,9 @@
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
@@ -258,7 +259,7 @@
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
+    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
         max_input_length = 0
@@ -545,7 +546,7 @@ class CausalLM(Model):
         )

     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {

View File

@@ -8,7 +8,7 @@ import numpy as np

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
@@ -108,6 +108,7 @@ class FlashCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
@@ -181,7 +182,7 @@
             next_token_chooser_parameters.append(r.parameters)
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
@@ -233,7 +234,7 @@
         )
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, model
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -468,7 +469,7 @@
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -590,6 +591,7 @@
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=tokenizer,
+            model=model,
         )

         speculative_ids = (

View File

@@ -6,7 +6,7 @@ import numpy as np

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type, List
@@ -48,6 +48,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
@@ -124,7 +125,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             next_token_chooser_parameters.append(r.parameters)
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
@@ -190,7 +191,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, model
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)

View File

@@ -8,6 +8,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
 )
 from text_generation_server.models import CausalLM
 from text_generation_server.models.causal_lm import CausalLMBatch
@@ -73,6 +74,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "GalacticaCausalLMBatch":
@@ -92,9 +94,9 @@
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
@@ -225,7 +227,7 @@ class GalacticaSharded(CausalLM):
         )

     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
     ):
         outputs = self.model.forward(
             input_ids=input_ids,

View File

@@ -7,6 +7,7 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
     ProcessorMixin,
 )
 from typing import Optional, Tuple, List, Type, Dict
@@ -96,6 +97,7 @@ class IdeficsCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         processor: ProcessorMixin,  # Hack
         dtype: torch.dtype,
         device: torch.device,
@@ -114,9 +116,9 @@
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)
@@ -642,7 +644,7 @@ class IdeficsCausalLM(Model):
         pixel_values,
         image_hidden_states,
         image_attention_mask,
-        past_key_values: Optional = None,
+        past_key_values: Optional = None,  # type: ignore
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {

View File

@@ -4,7 +4,7 @@ import torch.distributed

 from pathlib import Path
 from typing import Optional, Type
 from opentelemetry import trace
-from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel
 from huggingface_hub import hf_hub_download
 import json
@@ -29,10 +29,11 @@ class MPTCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device, model=model)
         batch.keys_head_dim_last = False
         return batch

View File

@@ -3,7 +3,7 @@ import time

 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.utils.tokens import batch_top_tokens
@@ -75,6 +75,7 @@ class Seq2SeqLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
@@ -96,9 +97,9 @@
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
@@ -278,7 +279,7 @@
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+    def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Seq2SeqLMBatch":
         """Concatenate multiple batches together by padding internal torch tensors"""
         # Used for padding
@@ -587,9 +588,9 @@ class Seq2SeqLM(Model):
         input_ids,
         attention_mask,
         decoder_input_ids,
-        decoder_attention_mask: Optional,
-        encoder_last_hidden_state: Optional,
-        past_key_values: Optional = None,
+        decoder_attention_mask: Optional,  # type: ignore
+        encoder_last_hidden_state: Optional,  # type: ignore
+        past_key_values: Optional = None,  # type: ignore
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod

 from dataclasses import dataclass
 from typing import List, Optional

-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel

 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
@@ -22,6 +22,7 @@ class Batch(ABC):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Batch":
@@ -33,7 +34,7 @@
     @classmethod
     @abstractmethod
-    def concatenate(cls, batches: List["Batch"]) -> "Batch":
+    def concatenate(cls, batches: List["Batch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Batch":
         raise NotImplementedError

     @abstractmethod

View File

@@ -85,12 +85,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model,
                 self.model.dtype,
                 self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+                request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
             )

         max_supported_total_tokens = self.model.warmup(batch)
@@ -107,12 +108,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model,
                 self.model.dtype,
                 self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+                request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
             )

         generations, next_batch, timings = self.model.generate_token(batch)
@@ -143,7 +145,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer, model=self.model)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]

View File

@@ -14,7 +14,7 @@ from text_generation_server.utils.logits_process import (
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor, PreTrainedModel
 from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
 from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
@@ -32,8 +32,11 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        model=None,
         use_grammar_constraint=False,
         grammar="",
+        guidance_scale=1.0,
+        negative_inputs="",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -50,6 +53,19 @@
         else:
             self.grammar_processor = None

+        if guidance_scale != 1.0 and model is not None and negative_inputs:
+            negative_inputs_t = tokenizer([negative_inputs], return_tensors="pt")
+            device = next(model.model.parameters()).device
+            self.guidance_scale_processor = UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                guidance_scale,
+                model.model,
+                unconditional_ids=negative_inputs_t["input_ids"].to(device),
+                unconditional_attention_mask=negative_inputs_t["attention_mask"].to(device),
+                use_cache=True,  # keep the cache permanently on for now
+            )
+        else:
+            self.guidance_scale_processor = None
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
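Note: UnbatchedClassifierFreeGuidanceLogitsProcessor (from transformers) runs a second forward pass of the wrapped model on the negative prompt and blends the two distributions. Roughly, the combination step looks like the sketch below (a simplification for intuition, not the exact library code); with guidance_scale == 1.0 it reduces to the ordinary conditional scores, which is why the processor is only built when the scale differs from 1.0:

import torch

def cfg_combine(cond_logits: torch.Tensor,
                uncond_logits: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance on log-probabilities: start from the
    # negative-prompt distribution and move guidance_scale times the
    # difference toward the conditional one.
    cond = torch.log_softmax(cond_logits, dim=-1)
    uncond = torch.log_softmax(uncond_logits, dim=-1)
    return uncond + guidance_scale * (cond - uncond)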
@@ -67,12 +83,15 @@
         self.choice = Sampling(seed, device) if sampling else Greedy()

     def __call__(self, input_ids, scores):
+        if self.guidance_scale_processor is not None:
+            scores = self.guidance_scale_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(input_ids, scores)
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
-        if self.grammar_processor is not None:
-            scores = self.grammar_processor(input_ids, scores)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -89,6 +108,7 @@
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -103,6 +123,9 @@
             use_grammar_constraint=pb.use_grammar_constraint,
             grammar=pb.grammar,
             tokenizer=tokenizer,
+            model=model,
+            guidance_scale=pb.guidance_scale,
+            negative_inputs=pb.negative_inputs,
         )
@@ -157,6 +180,7 @@ class StoppingCriteria:
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
@@ -199,9 +223,12 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         watermark: List[bool],
         use_grammar_constraint: List[bool],
         grammar: List[str],
+        guidance_scale: List[float],
+        negative_inputs: List[str],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],
@@ -232,10 +259,14 @@
             else None
         )

-        if use_grammar_constraint:
-            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
-            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
-            warpers.append(grammar_processor)
+        if any(use_grammar_constraint):
+            grammar_processors = {
+                i: GrammarConstrainedLogitsProcessor(IncrementalGrammarConstraint(grammar[i], "root", tokenizer))
+                for i, use_gc in enumerate(use_grammar_constraint) if use_gc
+            }
+            self.grammar_processor = HeterogeneousProcessorWrapper(grammar_processors)
+        else:
+            self.grammar_processor = None

         if any([x != 1.0 for x in temperature]):
             do_sample = [
@@ -294,6 +325,8 @@
             _scores = self.watermark_processor(input_ids, _scores)
         if self.repetition_processor is not None:
             _scores = self.repetition_processor(input_ids, _scores)
+        if self.grammar_processor is not None:
+            _scores = self.grammar_processor(input_ids, _scores)
         for warper in self.warpers:
             _scores = warper(input_ids, _scores)
@@ -385,6 +418,7 @@
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -398,8 +432,11 @@
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            use_grammar_constraint=use_grammar_constraint,
-            grammar=grammar,
+            model=model,
+            use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
+            grammar=[pb_.grammar for pb_ in pb],
+            guidance_scale=[pb_.guidance_scale for pb_ in pb],
+            negative_inputs=[pb_.negative_inputs for pb_ in pb],
         )
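Note: each LogitBias entry pairs a word with an additive bias. As an illustration only (this helper is not part of the commit; its name and exact token handling are assumptions), such entries can be applied as a logits processor by resolving every word to its token ids with the request tokenizer and adding the bias to those columns:

import torch

class WordLogitBiasProcessor:
    """Illustrative sketch: add a fixed bias to the logits of every token
    a biased word tokenizes to."""

    def __init__(self, logit_bias, tokenizer):
        # logit_bias: iterable of (word, bias) pairs, e.g. decoded LogitBias messages.
        self.token_bias = {}
        for word, bias in logit_bias:
            for token_id in tokenizer.encode(word, add_special_tokens=False):
                self.token_bias[token_id] = self.token_bias.get(token_id, 0.0) + bias

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # scores: [batch, vocab] next-token logits; bias is added in place per token id.
        for token_id, bias in self.token_bias.items():
            scores[..., token_id] += bias
        return scores

A processor along these lines could be chained in NextTokenChooser.__call__ next to the guidance, grammar, watermark and repetition processors.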