Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit d84b38e30d (parent 1108560745): adding guidance and extra parameters for token bias
@@ -51,6 +51,11 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}

message LogitBias {
    string word = 1;
    float bias = 2;
}

message NextTokenChooserParameters {
    /// exponential scaling output probability distribution
    float temperature = 1;

@@ -70,6 +75,9 @@ message NextTokenChooserParameters {
    bool watermark = 8;
    bool use_grammar_constraint = 9;
    string grammar = 10;
    repeated LogitBias logit_bias = 11;
    float guidance_scale = 12;
    string negative_inputs = 13;
}

message StoppingCriteriaParameters {
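For orientation, here is a hedged sketch of how the new protobuf fields could be populated from the regenerated Python stubs (text_generation_server.pb.generate_pb2 appears elsewhere in this patch; all field values below are made-up examples):

    from text_generation_server.pb import generate_pb2

    # Each LogitBias entry pairs a surface word with a bias value.
    biases = [
        generate_pb2.LogitBias(word="Paris", bias=5.0),
        generate_pb2.LogitBias(word="London", bias=-5.0),
    ]

    # The new fields ride alongside the existing sampling parameters.
    params = generate_pb2.NextTokenChooserParameters(
        temperature=0.7,
        watermark=False,
        use_grammar_constraint=False,
        grammar="",
        logit_bias=biases,
        guidance_scale=1.5,  # any value other than 1.0, plus a negative prompt, turns on guidance
        negative_inputs="low quality, off topic",
    )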
@@ -128,6 +128,10 @@ impl Client {
    watermark: true,
    use_grammar_constraint: false,
    grammar: "".to_string(),
    logit_bias: Vec::new(),
    guidance_scale: 1.0,
    negative_inputs: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
    max_new_tokens: max_total_tokens - truncate,

@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
    Request, StoppingCriteriaParameters, Tokens,
    Request, StoppingCriteriaParameters, Tokens, LogitBias,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;

@@ -46,6 +46,9 @@ impl Health {
    watermark: false,
    use_grammar_constraint: false,
    grammar: "".to_string(),
    logit_bias: Vec::new(),
    guidance_scale: 1.0,
    negative_inputs: "".to_string(),
}),
stopping_parameters: Some(StoppingCriteriaParameters {
    max_new_tokens: 1,
@@ -122,12 +122,21 @@ pub(crate) struct GenerateParameters {
    #[schema(default = "false", example = true)]
    pub watermark: bool,
    #[serde(default)]
    #[schema(default = "true")]
    #[schema(default = "false")]
    pub use_grammar_constraint: bool,
    #[serde(default)]
    #[schema(default = "false")]
    #[schema(default = "")]
    pub grammar: String,
    #[serde(default)]
    #[schema(default = "null", nullable = true)]
    pub logit_bias: Vec<(String, f32)>,
    #[serde(default)]
    #[schema(default = "1.0")]
    pub guidance_scale: f32,
    #[serde(default)]
    #[schema(default = "")]
    pub negative_inputs: String,
    #[serde(default)]
    #[schema(default = "")]
    pub details: bool,
    #[serde(default)]
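Assuming a router listening on localhost:3000 (host, port, and prompt are placeholders), the new knobs would travel in the existing parameters object of a /generate request; a Vec<(String, f32)> deserializes from two-element JSON arrays. A rough sketch:

    import requests

    payload = {
        "inputs": "Describe the Eiffel Tower.",
        "parameters": {
            "max_new_tokens": 64,
            "temperature": 0.7,
            "logit_bias": [["Paris", 5.0], ["London", -5.0]],  # (word, bias) pairs
            "guidance_scale": 1.5,
            "negative_inputs": "a story about cats",
            "use_grammar_constraint": False,
            "grammar": "",
        },
    }

    resp = requests.post("http://localhost:3000/generate", json=payload, timeout=60)
    print(resp.json()["generated_text"])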
@@ -166,6 +175,9 @@ fn default_parameters() -> GenerateParameters {
    watermark: false,
    use_grammar_constraint: false,
    grammar: "".to_string(),
    logit_bias: Vec::new(),
    guidance_scale: 1.0,
    negative_inputs: "".to_string(),
    details: false,
    decoder_input_details: false,
    seed: None,

@@ -356,6 +356,11 @@ mod tests {
    seed: 0,
    repetition_penalty: 0.0,
    watermark: false,
    use_grammar_constraint: false,
    grammar: "".to_string(),
    logit_bias: Vec::new(),
    guidance_scale: 1.0,
    negative_inputs: "".to_string(),
},
stopping_parameters: StoppingCriteriaParameters {
    ignore_eos_token: false,

@@ -2,7 +2,7 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest};
use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters, LogitBias};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;

@@ -165,6 +165,9 @@ impl Validation {
    watermark,
    use_grammar_constraint,
    grammar,
    guidance_scale,
    negative_inputs,
    logit_bias,
    decoder_input_details,
    top_n_tokens,
    ..

@@ -268,10 +271,19 @@ impl Validation {
    .unwrap_or(Ok(None))?;

// Validate inputs
let (inputs, input_length, max_new_tokens) = self
let (inputs, _input_length, max_new_tokens) = self
    .validate_input(request.inputs, truncate, max_new_tokens)
    .await?;

let (negative_inputs, input_length, max_new_tokens) = self
    .validate_input(negative_inputs, truncate, Some(max_new_tokens))
    .await?;

let logit_biases: Vec<LogitBias> = logit_bias
    .into_iter()
    .map(|(word, bias)| LogitBias { word, bias })
    .collect();

let parameters = NextTokenChooserParameters {
    temperature,
    repetition_penalty,

@@ -283,6 +295,9 @@ impl Validation {
    watermark,
    use_grammar_constraint,
    grammar,
    logit_bias: logit_biases,
    guidance_scale,
    negative_inputs,
};
let stopping_parameters = StoppingCriteriaParameters {
    max_new_tokens,
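The router only validates and forwards the word/bias pairs. As a rough illustration of what consuming them on the Python side could look like (this helper is hypothetical, not code from this patch), each word would be mapped to its token ids and the bias added to those columns of the next-token scores:

    import torch

    def apply_word_bias(scores, logit_bias, tokenizer):
        # scores: [batch, vocab_size] logits for the next token.
        for entry in logit_bias:
            token_ids = tokenizer.encode(entry.word, add_special_tokens=False)
            scores[:, token_ids] += entry.bias
        return scores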
@@ -7,6 +7,7 @@ from transformers import (
    AutoTokenizer,
    AutoConfig,
    PreTrainedTokenizerBase,
    PreTrainedModel,
)

from text_generation_server.models.custom_modeling.bloom_modeling import (

@@ -28,10 +29,11 @@ class BloomCausalLMBatch(CausalLMBatch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "CausalLMBatch":
    batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
    batch = super().from_pb(pb=pb, tokenizer=tokenizer, model=model, dtype=dtype, device=device)
    batch.keys_head_dim_last = False
    return batch

@@ -101,7 +103,7 @@ class BLOOMSharded(CausalLM):
    return BloomCausalLMBatch

def forward(
    self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
):
    outputs = self.model.forward(
        input_ids=input_ids,
@@ -3,7 +3,7 @@ import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model

@@ -69,6 +69,7 @@ class CausalLMBatch(Batch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "CausalLMBatch":

@@ -87,9 +88,9 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
    requests_idx_mapping[r.id] = i
    inputs.append(r.inputs)
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
    stopping_criteria = StoppingCriteria.from_pb(
        r.stopping_parameters, tokenizer
        r.stopping_parameters, tokenizer, model
    )
    stopping_criterias.append(stopping_criteria)
    top_n_tokens.append(r.top_n_tokens)

@@ -258,7 +259,7 @@ class CausalLMBatch(Batch):

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "CausalLMBatch":
    # Used for padding
    total_batch_size = 0
    max_input_length = 0

@@ -545,7 +546,7 @@ class CausalLM(Model):
)

def forward(
    self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
    # Model Forward
    kwargs = {
@@ -8,7 +8,7 @@ import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model

@@ -108,6 +108,7 @@ class FlashCausalLMBatch(Batch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "FlashCausalLMBatch":

@@ -181,7 +182,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.append(r.parameters)

stopping_criteria = StoppingCriteria.from_pb(
    r.stopping_parameters, tokenizer
    r.stopping_parameters, tokenizer, model
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)

@@ -233,7 +234,7 @@ class FlashCausalLMBatch(Batch):
)

next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
    next_token_chooser_parameters, dtype, device, tokenizer
    next_token_chooser_parameters, dtype, device, tokenizer, model
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)

@@ -468,7 +469,7 @@ class FlashCausalLMBatch(Batch):

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "FlashCausalLMBatch":
    # Batch attributes
    requests = []
    requests_idx_mapping = {}

@@ -590,6 +591,7 @@ class FlashCausalLMBatch(Batch):
    dtype=batches[0].next_token_chooser.dtype,
    device=batches[0].next_token_chooser.device,
    tokenizer=tokenizer,
    model=model,
)

speculative_ids = (
@@ -6,7 +6,7 @@ import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type, List

@@ -48,6 +48,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "FlashCausalLMBatch":

@@ -124,7 +125,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
next_token_chooser_parameters.append(r.parameters)

stopping_criteria = StoppingCriteria.from_pb(
    r.stopping_parameters, tokenizer
    r.stopping_parameters, tokenizer, model
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)

@@ -190,7 +191,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
)

next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
    next_token_chooser_parameters, dtype, device, tokenizer
    next_token_chooser_parameters, dtype, device, tokenizer, model
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -8,6 +8,7 @@ from transformers import (
    AutoTokenizer,
    AutoConfig,
    PreTrainedTokenizerBase,
    PreTrainedModel,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch

@@ -73,6 +74,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "GalacticaCausalLMBatch":

@@ -92,9 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
    requests_idx_mapping[r.id] = i
    # Add escape_custom_split_sequence to the CausalLMBatch logic
    inputs.append(escape_custom_split_sequence(r.inputs))
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
    stopping_criteria = StoppingCriteria.from_pb(
        r.stopping_parameters, tokenizer
        r.stopping_parameters, tokenizer, model
    )
    stopping_criterias.append(stopping_criteria)
    top_n_tokens.append(r.top_n_tokens)

@@ -225,7 +227,7 @@ class GalacticaSharded(CausalLM):
)

def forward(
    self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    self, input_ids, attention_mask, position_ids, past_key_values: Optional = None  # type: ignore
):
    outputs = self.model.forward(
        input_ids=input_ids,
@@ -7,6 +7,7 @@ from transformers import (
    AutoProcessor,
    AutoTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedModel,
    ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict

@@ -96,6 +97,7 @@ class IdeficsCausalLMBatch(Batch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    processor: ProcessorMixin,  # Hack
    dtype: torch.dtype,
    device: torch.device,

@@ -114,9 +116,9 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
    requests_idx_mapping[r.id] = i
    inputs.append(r.inputs)
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
    stopping_criteria = StoppingCriteria.from_pb(
        r.stopping_parameters, tokenizer
        r.stopping_parameters, tokenizer, model
    )
    stopping_criterias.append(stopping_criteria)
    max_truncation = max(max_truncation, r.truncate)

@@ -642,7 +644,7 @@ class IdeficsCausalLM(Model):
    pixel_values,
    image_hidden_states,
    image_attention_mask,
    past_key_values: Optional = None,
    past_key_values: Optional = None,  # type: ignore
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
    # Model Forward
    kwargs = {
@@ -4,7 +4,7 @@ import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel
from huggingface_hub import hf_hub_download
import json

@@ -29,10 +29,11 @@ class MPTCausalLMBatch(CausalLMBatch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "CausalLMBatch":
    batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
    batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device, model=model)
    batch.keys_head_dim_last = False
    return batch
@@ -3,7 +3,7 @@ import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils.tokens import batch_top_tokens

@@ -75,6 +75,7 @@ class Seq2SeqLMBatch(Batch):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "Seq2SeqLMBatch":

@@ -96,9 +97,9 @@ class Seq2SeqLMBatch(Batch):
    inputs.append(r.inputs)
    requests_idx_mapping[r.id] = i
    decoder_input_lengths.append(1)
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
    next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
    stopping_criteria = StoppingCriteria.from_pb(
        r.stopping_parameters, tokenizer
        r.stopping_parameters, tokenizer, model
    )
    stopping_criterias.append(stopping_criteria)
    top_n_tokens.append(r.top_n_tokens)

@@ -278,7 +279,7 @@ class Seq2SeqLMBatch(Batch):

@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Seq2SeqLMBatch":
    """Concatenate multiple batches together by padding internal torch tensors"""

    # Used for padding

@@ -587,9 +588,9 @@ class Seq2SeqLM(Model):
    input_ids,
    attention_mask,
    decoder_input_ids,
    decoder_attention_mask: Optional,
    encoder_last_hidden_state: Optional,
    past_key_values: Optional = None,
    decoder_attention_mask: Optional,  # type: ignore
    encoder_last_hidden_state: Optional,  # type: ignore
    past_key_values: Optional = None,  # type: ignore
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional

from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason

@@ -22,6 +22,7 @@ class Batch(ABC):
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    dtype: torch.dtype,
    device: torch.device,
) -> "Batch":

@@ -33,7 +34,7 @@ class Batch(ABC):

@classmethod
@abstractmethod
def concatenate(cls, batches: List["Batch"]) -> "Batch":
def concatenate(cls, batches: List["Batch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Batch":
    raise NotImplementedError

@abstractmethod
@@ -85,12 +85,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        request.batch,
        self.model.tokenizer,
        self.model.processor,
        self.model,
        self.model.dtype,
        self.model.device,
    )
else:
    batch = self.model.batch_type.from_pb(
        request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
    )
max_supported_total_tokens = self.model.warmup(batch)

@@ -107,12 +108,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        request.batch,
        self.model.tokenizer,
        self.model.processor,
        self.model,
        self.model.dtype,
        self.model.device,
    )
else:
    batch = self.model.batch_type.from_pb(
        request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
    )

generations, next_batch, timings = self.model.generate_token(batch)

@@ -143,7 +145,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

if len(batches) > 1:
    start_concat = time.time_ns()
    batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
    batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer, model=self.model)
    concat_ns = time.time_ns() - start_concat
else:
    batch = batches[0]
@@ -14,7 +14,7 @@ from text_generation_server.utils.logits_process import (
    static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor, PreTrainedModel
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

@@ -32,8 +32,11 @@ class NextTokenChooser:
    seed=0,
    device="cpu",
    tokenizer=None,
    model=None,
    use_grammar_constraint=False,
    grammar="",
    guidance_scale=1.0,
    negative_inputs="",
):
    self.watermark_processor = (
        WatermarkLogitsProcessor(device=device) if watermark else None

@@ -50,6 +53,19 @@ class NextTokenChooser:
    else:
        self.grammar_processor = None

    if guidance_scale != 1.0 and model is not None and negative_inputs:
        negative_inputs_t = tokenizer([negative_inputs], return_tensors="pt")
        device = next(model.model.parameters()).device
        self.guidance_scale_processor = UnbatchedClassifierFreeGuidanceLogitsProcessor(
            guidance_scale,
            model.model,
            unconditional_ids=negative_inputs_t["input_ids"].to(device),
            unconditional_attention_mask=negative_inputs_t["attention_mask"].to(device),
            use_cache=True,  # use cache permanently on for now.
        )
    else:
        self.guidance_scale_processor = None

    has_warpers = (
        (temperature is not None and temperature != 1.0)
        or (top_k is not None and top_k != 0)
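For context, UnbatchedClassifierFreeGuidanceLogitsProcessor from transformers runs an extra forward pass on the negative prompt and blends the two score distributions. Conceptually the blend is roughly the following (a simplified sketch, not the exact library code):

    import torch

    def cfg_blend(cond_scores, uncond_scores, guidance_scale):
        # Work in log-probability space, then push the conditional
        # distribution away from the unconditional (negative-prompt) one.
        cond = torch.log_softmax(cond_scores, dim=-1)
        uncond = torch.log_softmax(uncond_scores, dim=-1)
        return uncond + guidance_scale * (cond - uncond)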
@@ -67,12 +83,15 @@ class NextTokenChooser:
    self.choice = Sampling(seed, device) if sampling else Greedy()

def __call__(self, input_ids, scores):
    if self.guidance_scale_processor is not None:
        scores = self.guidance_scale_processor(input_ids, scores)
    if self.grammar_processor is not None:
        scores = self.grammar_processor(input_ids, scores)
    if self.watermark_processor is not None:
        scores = self.watermark_processor(input_ids, scores)
    if self.repetition_processor is not None:
        scores = self.repetition_processor(input_ids, scores)
    if self.grammar_processor is not None:
        scores = self.grammar_processor(input_ids, scores)

    if self.static_warper is None:
        next_logprob = torch.log_softmax(scores, -1)

@@ -89,6 +108,7 @@ class NextTokenChooser:
    pb: generate_pb2.NextTokenChooserParameters,
    device: torch.device,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
) -> "NextTokenChooser":
    return NextTokenChooser(
        watermark=pb.watermark,

@@ -103,6 +123,9 @@ class NextTokenChooser:
    use_grammar_constraint=pb.use_grammar_constraint,
    grammar=pb.grammar,
    tokenizer=tokenizer,
    model=model,
    guidance_scale=pb.guidance_scale,
    negative_inputs=pb.negative_inputs,
)

@@ -157,6 +180,7 @@ class StoppingCriteria:
    cls,
    pb: generate_pb2.StoppingCriteriaParameters,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
) -> "StoppingCriteria":
    stop_sequence_criterias = [
        StopSequenceCriteria(sequence) for sequence in pb.stop_sequences

@@ -199,9 +223,12 @@ class HeterogeneousNextTokenChooser:
    dtype: torch.dtype,
    device: torch.device,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    watermark: List[bool],
    use_grammar_constraint: List[bool],
    grammar: List[str],
    guidance_scale: List[float],
    negative_inputs: List[str],
    temperature: List[float],
    repetition_penalty: List[float],
    top_k: List[int],

@@ -232,10 +259,14 @@ class HeterogeneousNextTokenChooser:
    else None
)

if use_grammar_constraint:
    grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
    warpers.append(grammar_processor)
if any(use_grammar_constraint):
    grammar_processors = {
        i: GrammarConstrainedLogitsProcessor(IncrementalGrammarConstraint(grammar[i], "root", tokenizer))
        for i, use_gc in enumerate(use_grammar_constraint) if use_gc
    }
    self.grammar_processor = HeterogeneousProcessorWrapper(grammar_processors)
else:
    self.grammar_processor = None

if any([x != 1.0 for x in temperature]):
    do_sample = [
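As used above, transformers_cfg builds an IncrementalGrammarConstraint from an EBNF-style grammar whose start symbol is "root", then wraps it in a GrammarConstrainedLogitsProcessor. A standalone sketch with a toy grammar (the model name and grammar text are placeholders):

    from transformers import AutoTokenizer
    from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
    from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

    # Toy grammar: only allow "yes" or "no" as the generated answer.
    grammar_str = 'root ::= "yes" | "no"'

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    constraint = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    processor = GrammarConstrainedLogitsProcessor(constraint)

    # Applied like any other logits processor:
    # scores = processor(input_ids, scores)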
@@ -294,6 +325,8 @@ class HeterogeneousNextTokenChooser:
    _scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
    _scores = self.repetition_processor(input_ids, _scores)
if self.grammar_processor is not None:
    _scores = self.grammar_processor(input_ids, _scores)

for warper in self.warpers:
    _scores = warper(input_ids, _scores)

@@ -385,6 +418,7 @@ class HeterogeneousNextTokenChooser:
    dtype: torch.dtype,
    device: torch.device,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
) -> "HeterogeneousNextTokenChooser":
    return HeterogeneousNextTokenChooser(
        watermark=[pb_.watermark for pb_ in pb],

@@ -398,8 +432,11 @@ class HeterogeneousNextTokenChooser:
    device=device,
    dtype=dtype,
    tokenizer=tokenizer,
    use_grammar_constraint=use_grammar_constraint,
    grammar=grammar,
    model=model,
    use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
    grammar=[pb_.grammar for pb_ in pb],
    guidance_scale=[pb_.guidance_scale for pb_ in pb],
    negative_inputs=[pb_.negative_inputs for pb_ in pb],
)