Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

adding guidance and extra parameters for token bias

commit d84b38e30d
parent 1108560745
@@ -51,6 +51,11 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}
 
+message LogitBias {
+    string word = 1;
+    float bias = 2;
+}
+
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;

@@ -70,6 +75,9 @@ message NextTokenChooserParameters {
     bool watermark = 8;
     bool use_grammar_constraint = 9;
     string grammar = 10;
+    repeated LogitBias logit_bias = 11;
+    float guidance_scale = 12;
+    string negative_inputs = 13;
 }
 
 message StoppingCriteriaParameters {
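Note: the two proto hunks above introduce the LogitBias message and the new logit_bias, guidance_scale and negative_inputs fields on NextTokenChooserParameters. As a rough illustration only (not part of the commit; it relies solely on the generated generate_pb2 bindings already used elsewhere in this diff), a shard-side caller could build the new parameters like this:

# Illustration: constructing the new parameters through the generated protobuf bindings.
from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=0.7,
    watermark=False,
    use_grammar_constraint=False,
    grammar="",
    # Bias specific words up or down before sampling.
    logit_bias=[
        generate_pb2.LogitBias(word="Paris", bias=5.0),
        generate_pb2.LogitBias(word="London", bias=-5.0),
    ],
    # Classifier-free guidance; 1.0 leaves it disabled.
    guidance_scale=1.5,
    negative_inputs="low quality, incoherent",
)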
@@ -128,6 +128,10 @@ impl Client {
                     watermark: true,
                     use_grammar_constraint: false,
                     grammar: "".to_string(),
+                    logit_bias: Vec::new(),
+                    guidance_scale: 1.0,
+                    negative_inputs: "".to_string(),
+
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
                     max_new_tokens: max_total_tokens - truncate,

@@ -10,7 +10,7 @@ pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    Request, StoppingCriteriaParameters, Tokens,
+    Request, StoppingCriteriaParameters, Tokens, LogitBias,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

@@ -46,6 +46,9 @@ impl Health {
             watermark: false,
             use_grammar_constraint: false,
             grammar: "".to_string(),
+            logit_bias: Vec::new(),
+            guidance_scale: 1.0,
+            negative_inputs: "".to_string(),
         }),
         stopping_parameters: Some(StoppingCriteriaParameters {
             max_new_tokens: 1,

@@ -122,12 +122,21 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "false", example = true)]
     pub watermark: bool,
     #[serde(default)]
-    #[schema(default = "true")]
+    #[schema(default = "false")]
     pub use_grammar_constraint: bool,
     #[serde(default)]
-    #[schema(default = "false")]
+    #[schema(default = "")]
     pub grammar: String,
     #[serde(default)]
+    #[schema(default = "null", nullable = true)]
+    pub logit_bias: Vec<(String, f32)>,
+    #[serde(default)]
+    #[schema(default = "1.0")]
+    pub guidance_scale: f32,
+    #[serde(default)]
+    #[schema(default = "")]
+    pub negative_inputs: String,
+    #[serde(default)]
     #[schema(default = "")]
     pub details: bool,
     #[serde(default)]

@@ -166,6 +175,9 @@ fn default_parameters() -> GenerateParameters {
         watermark: false,
         use_grammar_constraint: false,
         grammar: "".to_string(),
+        logit_bias: Vec::new(),
+        guidance_scale: 1.0,
+        negative_inputs: "".to_string(),
         details: false,
         decoder_input_details: false,
         seed: None,
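Note: with the GenerateParameters fields above, the router's /generate endpoint can carry token biases, a guidance scale and a negative prompt. The sketch below is not part of the commit; the host, port and prompt are placeholders. It only shows the request shape these definitions imply; Vec<(String, f32)> deserializes from a JSON array of [word, bias] pairs:

# Illustration: what a request using the new parameters could look like.
import requests

payload = {
    "inputs": "A photo caption for a sunny beach:",
    "parameters": {
        "max_new_tokens": 32,
        "logit_bias": [["sun", 4.0], ["rain", -4.0]],   # Vec<(String, f32)>
        "guidance_scale": 1.5,                           # 1.0 = guidance disabled
        "negative_inputs": "blurry, dark",
        "use_grammar_constraint": False,
        "grammar": "",
    },
}
resp = requests.post("http://localhost:8080/generate", json=payload)
print(resp.json())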
@@ -356,6 +356,11 @@ mod tests {
             seed: 0,
             repetition_penalty: 0.0,
             watermark: false,
+            use_grammar_constraint: false,
+            grammar: "".to_string(),
+            logit_bias: Vec::new(),
+            guidance_scale: 1.0,
+            negative_inputs: "".to_string(),
         },
         stopping_parameters: StoppingCriteriaParameters {
             ignore_eos_token: false,

@@ -2,7 +2,7 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters, LogitBias};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;

@@ -165,6 +165,9 @@ impl Validation {
             watermark,
             use_grammar_constraint,
             grammar,
+            guidance_scale,
+            negative_inputs,
+            logit_bias,
             decoder_input_details,
             top_n_tokens,
             ..

@@ -268,10 +271,19 @@ impl Validation {
             .unwrap_or(Ok(None))?;

         // Validate inputs
-        let (inputs, input_length, max_new_tokens) = self
+        let (inputs, _input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;

+        let (negative_inputs, input_length, max_new_tokens) = self
+            .validate_input(negative_inputs, truncate, Some(max_new_tokens))
+            .await?;
+
+        let logit_biases: Vec<LogitBias> = logit_bias
+            .into_iter()
+            .map(|(word, bias)| LogitBias { word, bias })
+            .collect();
+
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,

@@ -283,6 +295,9 @@ impl Validation {
             watermark,
             use_grammar_constraint,
             grammar,
+            logit_bias: logit_biases,
+            guidance_scale,
+            negative_inputs,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
@@ -7,6 +7,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
 )
 
 from text_generation_server.models.custom_modeling.bloom_modeling import (

@@ -28,10 +29,11 @@ class BloomCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, model=model, dtype=dtype, device=device)
         batch.keys_head_dim_last = False
         return batch
 

@@ -101,7 +103,7 @@ class BLOOMSharded(CausalLM):
         return BloomCausalLMBatch
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None #type: ignore
     ):
         outputs = self.model.forward(
             input_ids=input_ids,

@@ -3,7 +3,7 @@ import time
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model

@@ -69,6 +69,7 @@ class CausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":

@@ -87,9 +88,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)

@@ -258,7 +259,7 @@ class CausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "CausalLMBatch":
+    def concatenate(cls, batches: List["CausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
         max_input_length = 0

@@ -545,7 +546,7 @@ class CausalLM(Model):
         )
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None # type: ignore
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {

@@ -8,7 +8,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model

@@ -108,6 +108,7 @@ class FlashCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":

@@ -181,7 +182,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.append(r.parameters)
 
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

@@ -233,7 +234,7 @@ class FlashCausalLMBatch(Batch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, model
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 

@@ -468,7 +469,7 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None) -> "FlashCausalLMBatch":
+    def concatenate(cls, batches: List["FlashCausalLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}

@@ -590,6 +591,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=tokenizer,
+            model=model,
         )
 
         speculative_ids = (
@@ -6,7 +6,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type, List
 

@@ -48,6 +48,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":

@@ -124,7 +125,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             next_token_chooser_parameters.append(r.parameters)
 
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

@@ -190,7 +191,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
+            next_token_chooser_parameters, dtype, device, tokenizer, model
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 

@@ -8,6 +8,7 @@ from transformers import (
     AutoTokenizer,
     AutoConfig,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
 )
 from text_generation_server.models import CausalLM
 from text_generation_server.models.causal_lm import CausalLMBatch

@@ -73,6 +74,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "GalacticaCausalLMBatch":

@@ -92,9 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)

@@ -225,7 +227,7 @@ class GalacticaSharded(CausalLM):
         )
 
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None #type: ignore
     ):
         outputs = self.model.forward(
             input_ids=input_ids,

@@ -7,6 +7,7 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     PreTrainedTokenizerBase,
+    PreTrainedModel,
     ProcessorMixin,
 )
 from typing import Optional, Tuple, List, Type, Dict

@@ -96,6 +97,7 @@ class IdeficsCausalLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         processor: ProcessorMixin,  # Hack
         dtype: torch.dtype,
         device: torch.device,

@@ -114,9 +116,9 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             max_truncation = max(max_truncation, r.truncate)

@@ -642,7 +644,7 @@ class IdeficsCausalLM(Model):
         pixel_values,
         image_hidden_states,
         image_attention_mask,
-        past_key_values: Optional = None,
+        past_key_values: Optional = None, #type: ignore
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
@@ -4,7 +4,7 @@ import torch.distributed
 from pathlib import Path
 from typing import Optional, Type
 from opentelemetry import trace
-from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase, PreTrainedModel
 from huggingface_hub import hf_hub_download
 import json
 

@@ -29,10 +29,11 @@ class MPTCausalLMBatch(CausalLMBatch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device, model=model)
         batch.keys_head_dim_last = False
         return batch
 

@@ -3,7 +3,7 @@ import time
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.utils.tokens import batch_top_tokens

@@ -75,6 +75,7 @@ class Seq2SeqLMBatch(Batch):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":

@@ -96,9 +97,9 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer, model))
             stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
+                r.stopping_parameters, tokenizer, model
             )
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)

@@ -278,7 +279,7 @@ class Seq2SeqLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+    def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Seq2SeqLMBatch":
         """Concatenate multiple batches together by padding internal torch tensors"""
 
         # Used for padding

@@ -587,9 +588,9 @@ class Seq2SeqLM(Model):
         input_ids,
         attention_mask,
         decoder_input_ids,
-        decoder_attention_mask: Optional,
-        encoder_last_hidden_state: Optional,
-        past_key_values: Optional = None,
+        decoder_attention_mask: Optional, #type: ignore
+        encoder_last_hidden_state: Optional, #type: ignore
+        past_key_values: Optional = None, #type: ignore
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import List, Optional
 
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason

@@ -22,6 +22,7 @@ class Batch(ABC):
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Batch":

@@ -33,7 +34,7 @@ class Batch(ABC):
 
     @classmethod
     @abstractmethod
-    def concatenate(cls, batches: List["Batch"]) -> "Batch":
+    def concatenate(cls, batches: List["Batch"], tokenizer: Optional[PreTrainedTokenizerBase] = None, model: Optional[PreTrainedModel] = None) -> "Batch":
         raise NotImplementedError
 
     @abstractmethod
@@ -85,12 +85,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model,
                 self.model.dtype,
                 self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+                request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
             )
         max_supported_total_tokens = self.model.warmup(batch)
 

@@ -107,12 +108,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch,
                 self.model.tokenizer,
                 self.model.processor,
+                self.model,
                 self.model.dtype,
                 self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+                request.batch, self.model.tokenizer, self.model, self.model.dtype, self.model.device
             )
 
         generations, next_batch, timings = self.model.generate_token(batch)

@@ -143,7 +145,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             start_concat = time.time_ns()
-            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer)
+            batch = self.model.batch_type.concatenate(batches, tokenizer=self.model.tokenizer, model=self.model)
             concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]

@@ -14,7 +14,7 @@ from text_generation_server.utils.logits_process import (
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers import PreTrainedTokenizerBase,RepetitionPenaltyLogitsProcessor,UnbatchedClassifierFreeGuidanceLogitsProcessor,PreTrainedModel
 from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
 from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
 
@@ -32,8 +32,11 @@ class NextTokenChooser:
         seed=0,
         device="cpu",
         tokenizer=None,
+        model=None,
         use_grammar_constraint=False,
         grammar="",
+        guidance_scale=1.0,
+        negative_inputs="",
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None

@@ -50,6 +53,19 @@ class NextTokenChooser:
         else:
             self.grammar_processor = None
 
+        if guidance_scale != 1.0 and model is not None and negative_inputs:
+            negative_inputs_t = tokenizer([negative_inputs], return_tensors="pt")
+            device = next(model.model.parameters()).device
+            self.guidance_scale_processor = UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                guidance_scale,
+                model.model,
+                unconditional_ids=negative_inputs_t["input_ids"].to(device),
+                unconditional_attention_mask=negative_inputs_t["attention_mask"].to(device),
+                use_cache=True,  # use cache permanently on for now.
+            )
+        else:
+            self.guidance_scale_processor = None
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
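Note: UnbatchedClassifierFreeGuidanceLogitsProcessor ships with transformers; it runs a second forward pass on the negative prompt and blends the two distributions. A simplified sketch of the combination it applies (illustration only; the real processor also manages the unconditional forward pass and its KV cache):

# Simplified sketch of classifier-free guidance on next-token logits.
import torch

def cfg_combine(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    cond = torch.nn.functional.log_softmax(cond_logits, dim=-1)
    uncond = torch.nn.functional.log_softmax(uncond_logits, dim=-1)
    # Push the distribution away from the negative-prompt distribution, scaled by guidance_scale.
    return uncond + guidance_scale * (cond - uncond)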
@@ -67,12 +83,15 @@ class NextTokenChooser:
         self.choice = Sampling(seed, device) if sampling else Greedy()
 
     def __call__(self, input_ids, scores):
+        if self.guidance_scale_processor is not None:
+            scores = self.guidance_scale_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(input_ids, scores)
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
-        if self.grammar_processor is not None:
-            scores = self.grammar_processor(input_ids, scores)
+
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)

@@ -89,6 +108,7 @@ class NextTokenChooser:
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,

@@ -103,6 +123,9 @@ class NextTokenChooser:
             use_grammar_constraint=pb.use_grammar_constraint,
             grammar=pb.grammar,
             tokenizer=tokenizer,
+            model=model,
+            guidance_scale=pb.guidance_scale,
+            negative_inputs=pb.negative_inputs,
         )
 
 

@@ -157,6 +180,7 @@ class StoppingCriteria:
         cls,
         pb: generate_pb2.StoppingCriteriaParameters,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences

@@ -199,9 +223,12 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
         watermark: List[bool],
         use_grammar_constraint: List[bool],
         grammar: List[str],
+        guidance_scale: List[float],
+        negative_inputs: List[str],
         temperature: List[float],
         repetition_penalty: List[float],
         top_k: List[int],

@@ -232,10 +259,14 @@ class HeterogeneousNextTokenChooser:
             else None
         )
 
-        if use_grammar_constraint:
-            grammar = IncrementalGrammarConstraint(grammar, "root", tokenizer)
-            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
-            warpers.append(grammar_processor)
+        if any(use_grammar_constraint):
+            grammar_processors = {
+                i: GrammarConstrainedLogitsProcessor(IncrementalGrammarConstraint(grammar[i], "root", tokenizer))
+                for i, use_gc in enumerate(use_grammar_constraint) if use_gc
+            }
+            self.grammar_processor = HeterogeneousProcessorWrapper(grammar_processors)
+        else:
+            self.grammar_processor = None
 
         if any([x != 1.0 for x in temperature]):
             do_sample = [
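Note: in the batched chooser above, grammar processors are built per request and wrapped so each one only rewrites the scores of its own row. A minimal sketch of that wrapping behaviour (illustration only; the hunk relies on TGI's existing HeterogeneousProcessorWrapper, presumably defined alongside the other processors in text_generation_server/utils/logits_process.py):

# Minimal sketch: apply per-request logits processors to their own batch rows.
# Not the production implementation, only the behaviour the hunk above assumes.
from typing import Callable, Dict

import torch

class SimpleHeterogeneousProcessorWrapper:
    def __init__(self, processors: Dict[int, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]):
        # Maps batch index -> logits processor, only for requests that need one.
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for i, processor in self.processors.items():
            # Keep the slice 2D, since processors expect a batch dimension.
            scores[i:i + 1] = processor(input_ids[i:i + 1], scores[i:i + 1])
        return scores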
@@ -294,6 +325,8 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.watermark_processor(input_ids, _scores)
             if self.repetition_processor is not None:
                 _scores = self.repetition_processor(input_ids, _scores)
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(input_ids, _scores)
 
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)

@@ -385,6 +418,7 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        model: PreTrainedModel,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],

@@ -398,8 +432,11 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            use_grammar_constraint=use_grammar_constraint,
-            grammar=grammar,
+            model=model,
+            use_grammar_constraint=[pb_.use_grammar_constraint for pb_ in pb],
+            grammar=[pb_.grammar for pb_ in pb],
+            guidance_scale=[pb_.guidance_scale for pb_ in pb],
+            negative_inputs=[pb_.negative_inputs for pb_ in pb],
         )
 
 