Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-08 18:32:06 +00:00)
fix: improve processor logic and refactor

commit 9eeccbf9a5 (parent bb5c875f0b)
@@ -1192,7 +1192,6 @@ pub(crate) async fn chat_completions(
     let (generate_request, using_tools): (GenerateRequest, bool) =
         chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
-    println!("ChatRequest: {:#?}", generate_request);
     let logprobs = logprobs.unwrap_or_default();

     // extract model id from request if specified
@@ -34,6 +34,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     disable_grammar_support: bool,
+    vocab_size: u32,
     /// Channel to communicate with the background tokenization task
     sender: mpsc::UnboundedSender<TokenizerRequest>,
 }
@@ -88,6 +89,19 @@ impl Validation {
             validation_sender
         };

+        let vocab_size = match &tokenizer {
+            Tokenizer::Python { tokenizer_name, .. } => {
+                warn!(
+                    "Tokenizer {} is not supported for validation",
+                    tokenizer_name
+                );
+                0
+            }
+            Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false),
+        }
+        .try_into()
+        .unwrap_or(0);
+
         Self {
             max_best_of,
             sender,
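For reference, a rough Python analogue of the fallback this hunk introduces on the Rust side: read the base vocabulary size (without added tokens) when a fast tokenizer is available, otherwise fall back to 0 so the range check is effectively skipped. The helper name and file path below are illustrative, not part of the diff.

from tokenizers import Tokenizer

def vocab_size_or_zero(tokenizer_json: str = "tokenizer.json") -> int:
    """Mirror of the Rust fallback: base vocab size, or 0 if it cannot be read."""
    try:
        tok = Tokenizer.from_file(tokenizer_json)
        # with_added_tokens=False matches get_vocab_size(false) in the Rust hunk
        return tok.get_vocab_size(with_added_tokens=False)
    except Exception:
        return 0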
@@ -96,6 +110,7 @@ impl Validation {
             max_input_length,
             max_total_tokens,
             disable_grammar_support,
+            vocab_size,
         }
     }

@@ -409,6 +424,35 @@ impl Validation {
             None => None,
         };

+        let logit_bias = match &request.parameters.logit_bias {
+            Some(bias) if !bias.is_empty() => {
+                for (token_str, _) in bias.iter() {
+                    let token_id = token_str.parse::<u32>().map_err(|_| {
+                        ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is not a valid number.",
+                            token_str
+                        ))
+                    })?;
+
+                    if token_id >= self.vocab_size {
+                        return Err(ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is out of range. Must be between 0 and {}.",
+                            token_id,
+                            self.vocab_size - 1
+                        )));
+                    }
+                }
+
+                // Transform into the required format
+                Some(
+                    bias.iter()
+                        .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
+                        .collect(),
+                )
+            }
+            _ => None,
+        };
+
         let parameters = ValidParameters {
             temperature,
             repetition_penalty,
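The added match arm enforces two rules before the bias reaches the backend: every key must parse as an integer token ID, and that ID must be strictly below the vocabulary size. A minimal Python sketch of the same checks (hypothetical helper, not part of the diff):

from typing import Dict, Optional

def validate_logit_bias(
    logit_bias: Optional[Dict[str, float]], vocab_size: int
) -> Optional[Dict[int, float]]:
    """Reject non-numeric or out-of-range token IDs, then convert keys to ints."""
    if not logit_bias:
        return None
    validated = {}
    for token_str, bias in logit_bias.items():
        try:
            token_id = int(token_str)
        except ValueError:
            raise ValueError(f"Token ID {token_str} is not a valid number.")
        if not 0 <= token_id < vocab_size:
            raise ValueError(
                f"Token ID {token_id} is out of range. "
                f"Must be between 0 and {vocab_size - 1}."
            )
        validated[token_id] = float(bias)
    return validated

# e.g. validate_logit_bias({"9707": -100.0}, vocab_size=152064) -> {9707: -100.0}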
@@ -420,18 +464,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            logit_bias: Some(
-                request
-                    .parameters
-                    .logit_bias
-                    .iter()
-                    .flat_map(|bias| {
-                        bias.iter()
-                            .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
-                            .collect::<Vec<_>>()
-                    })
-                    .collect(),
-            ),
+            logit_bias,
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
@@ -1011,6 +1044,8 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("logit_bias is not valid: {0}")]
+    LogitBiasInvalid(String),
 }

 #[cfg(test)]
@@ -625,55 +625,49 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         return self


-class LogitBiasProcessor:
-    """Process logits with logit biases."""
+class LogitBiasProcessor(LogitsProcessor):
+    """
+    `LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their
+    corresponding bias values. Bias are applied to the logits during each forward pass.
+
+    Supports token IDs provided as strings (e.g., {"9707": -100}).
+    """

     def __init__(
-        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+        self,
+        logit_biases: dict,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ):
         self.tokenizer = tokenizer
-        self.logit_biases = logit_biases or {}
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"

-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        vocab_size = len(tokenizer)
+
+        # Convert keys to integers and values to a list
+        token_ids = torch.tensor(
+            [int(k) for k in logit_biases.keys()], dtype=torch.long
+        )
+        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
+
+        # Create a tensor and directly copy bias values at the corresponding indices
+        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no logit biases, return scores unchanged
-        if not self.logit_biases:
-            return scores
-
-        # Apply bias to the corresponding scores
-        for token_str, bias_value in self.logit_biases.items():
-            # Get token ID, either from cache or by computing it
-            if token_str not in self.token_id_mapping:
-                if token_str.isdigit():
-                    # If the token string is already a numeric ID
-                    token_id = int(token_str)
-                else:
-                    # Otherwise, use the tokenizer to get the ID
-                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                self.token_id_mapping[token_str] = token_id
-
-            token_id = self.token_id_mapping[token_str]
-
-            # Apply bias if token ID is valid
-            if 0 <= token_id < scores.size(-1):
-                scores[:, token_id] += bias_value
-
+        # Apply bias tensor as a broadcasted addition
+        if self.bias_tensor.shape[0] != scores.shape[1]:
+            # Fix if the bias tensor is smaller than the scores
+            self.bias_tensor = torch.nn.functional.pad(
+                self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
+            )
+        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
         return scores

     def filter(self, indices):
         """Keep only the logit biases for the specified indices."""
         new_logit_biases = {
             k: self.logit_biases[k] for k in indices if k in self.logit_biases
         }
         return LogitBiasProcessor(new_logit_biases, self.tokenizer)


-class HeterogeneousLogitBiasProcessor:
-    """Process logits with different logit biases for each sequence in the batch."""
+class HeterogeneousLogitBiasProcessor(LogitsProcessor):
+    """
+    Process logits with different logit biases for each sequence in the batch.
+    """

     def __init__(
         self,
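The refactor replaces the per-token Python loop with a bias tensor that is built once in __init__ and added to the logits in a single broadcasted operation. A self-contained sketch of that idea (standalone tensors, not the TGI class itself; values are illustrative):

import torch

vocab_size = 32
logit_biases = {"7": -100.0, "13": 5.0}  # token IDs arrive as strings

# Build the bias vector once, mirroring the new __init__
token_ids = torch.tensor([int(k) for k in logit_biases], dtype=torch.long)
bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

# Apply it to a batch of logits with one in-place broadcasted add
original = torch.randn(2, vocab_size)
scores = original.clone()
scores.add_(bias_tensor)

assert torch.allclose(scores[:, 7], original[:, 7] - 100.0)
assert torch.allclose(scores[:, 13], original[:, 13] + 5.0)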
@@ -681,54 +675,42 @@ class HeterogeneousLogitBiasProcessor:
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
         self.device = device
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        self.batch_size = len(logit_biases)
-        # import ipdb; ipdb.set_trace()
         self.vocab_size = len(tokenizer)

-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        # Create batch_size x vocab_size bias matrix
+        self.bias_matrix = torch.zeros(
+            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
+        )

-        # Create a mapping of indices that have logit biases
-        self.indices_with_biases = {
-            i: bias_dict
-            for i, bias_dict in enumerate(self.logit_biases)
-            if bias_dict is not None and len(bias_dict) > 0
-        }
+        # for each logit bias dictionary, convert keys to integers and values to a list
+        for i, logit_bias in enumerate(logit_biases):
+            token_ids = torch.tensor(
+                [int(k) for k in logit_bias.keys()], dtype=torch.long
+            ).to(device=device)
+            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
+                device=device
+            )
+            # Create a tensor and directly copy bias values at the corresponding indices
+            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no indices have biases, return scores unchanged
-        if not self.indices_with_biases:
-            return scores
-
-        # For each index with a bias, apply the bias to the corresponding scores
-        for i, bias_dict in self.indices_with_biases.items():
-            for token_str, bias_value in bias_dict.items():
-                # Get token ID, either from cache or by computing it
-                if token_str not in self.token_id_mapping:
-                    if token_str.isdigit():
-                        # If the token string is already a numeric ID
-                        token_id = int(token_str)
-                    else:
-                        # Otherwise, use the tokenizer to get the ID
-                        tokens = self.tokenizer.encode(
-                            token_str, add_special_tokens=False
-                        )
-                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                    self.token_id_mapping[token_str] = token_id
-
-                token_id = self.token_id_mapping[token_str]
-
-                # Apply bias if token ID is valid
-                if 0 <= token_id < scores.size(-1):
-                    scores[i, token_id] += bias_value
+        # Apply bias matrix as a broadcasted addition
+        if self.bias_matrix.shape[1] != scores.shape[1]:
+            # Fix if the bias matrix is smaller than the scores
+            self.bias_matrix = torch.nn.functional.pad(
+                self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
+            )
+
+        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
         return scores

-    def filter(self, indices: List[int]):
-        """Keep only the logit biases for the specified indices."""
+    def filter(self, indices):
         new_logit_biases = [self.logit_biases[i] for i in indices]
         if not any(bias and len(bias) > 0 for bias in new_logit_biases):
             return None
         return HeterogeneousLogitBiasProcessor(
             new_logit_biases, self.tokenizer, self.device
         )
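In the batched case the same idea becomes a (batch_size, vocab_size) matrix with one row of biases per request, so row i only ever touches sequence i. A small standalone sketch (shapes and values are illustrative, not taken from the diff):

import torch

vocab_size = 16
batch_biases = [{"3": -100.0}, {}, {"5": 2.0, "9": -1.0}]  # one dict per sequence

bias_matrix = torch.zeros((len(batch_biases), vocab_size), dtype=torch.float)
for i, biases in enumerate(batch_biases):
    if not biases:
        continue  # rows without biases stay all-zero
    ids = torch.tensor([int(k) for k in biases], dtype=torch.long)
    vals = torch.tensor(list(biases.values()), dtype=torch.float)
    bias_matrix[i].index_put_((ids,), vals, accumulate=True)

scores = torch.zeros(len(batch_biases), vocab_size)
scores.add_(bias_matrix)  # element-wise: each row biases only its own sequence

assert scores[0, 3] == -100.0
assert scores[1].abs().sum() == 0.0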
@@ -66,7 +66,6 @@ class NextTokenChooser:
             else None
         )
         self.tokenizer = tokenizer
         self.logit_bias = logit_bias

         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -136,7 +135,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
-            logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
+            logit_bias=pb.logit_bias,
         )

@@ -264,10 +263,6 @@ class HeterogeneousNextTokenChooser:
     ):
         warpers = []

-        # Initialize with empty logit biases if none provided
-        if logit_biases is None:
-            logit_biases = [None] * len(do_sample)
-
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -306,7 +301,7 @@ class HeterogeneousNextTokenChooser:

         self.logit_bias_processor = (
             HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
-            if any([bias is not None and len(bias) > 0 for bias in logit_biases])
+            if any([logit_bias is not None for logit_bias in logit_biases])
             else None
         )

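The two gate conditions in this hunk differ only for empty bias dicts ({}): one skips them, the other treats any non-None value as a reason to build the batch-level processor. A tiny illustration of that difference (the batch contents are made up):

logit_biases = [None, {}, None]  # a batch where one request sent an empty bias map

strict_gate = any(bias is not None and len(bias) > 0 for bias in logit_biases)
loose_gate = any(logit_bias is not None for logit_bias in logit_biases)

print(strict_gate, loose_gate)  # False True: only the looser gate builds the processor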
@@ -530,9 +525,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
-            logit_biases=[
-                dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
-            ],
+            logit_biases=[pb_.logit_bias for pb_ in pb],
         )
